From d591c819b0e090c3a8b2e5394b1c7175f32fe4b2 Mon Sep 17 00:00:00 2001 From: Travis Wilson <35748617+trrwilson@users.noreply.github.com> Date: Tue, 17 Oct 2023 18:53:15 -0700 Subject: [PATCH 01/19] DRAFT: unvetted proposal for flattened streaming --- .../AzureChatExtensionsMessageContext.cs | 33 ++++ .../src/Custom/OpenAIClient.cs | 100 +++++++++++ .../src/Custom/StreamingChatCompletions2.cs | 78 +++++++++ ...mingChatCompletionsUpdate.Serialization.cs | 164 ++++++++++++++++++ .../Custom/StreamingChatCompletionsUpdate.cs | 59 +++++++ .../tests/OpenAIInferenceTests.cs | 73 ++++++++ 6 files changed, 507 insertions(+) create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs new file mode 100644 index 0000000000000..1e903ab1326ba --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using System.Collections.Generic; +using Azure.Core; + +namespace Azure.AI.OpenAI +{ + /// + /// A representation of the additional context information available when Azure OpenAI chat extensions are involved + /// in the generation of a corresponding chat completions response. This context information is only populated when + /// using an Azure OpenAI request configured to use a matching extension. 
+ /// + public partial class AzureChatExtensionsMessageContext + { + public ContentFilterResults RequestContentFilterResults { get; internal set; } + public ContentFilterResults ResponseContentFilterResults { get; internal set; } + + internal AzureChatExtensionsMessageContext( + IList messages, + ContentFilterResults requestContentFilterResults, + ContentFilterResults responseContentFilterResults) + : this(messages) + { + RequestContentFilterResults = requestContentFilterResults; + ResponseContentFilterResults = responseContentFilterResults; + } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs index e46691d42ae88..d55ad4462a585 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs @@ -500,6 +500,106 @@ public virtual async Task> GetChatCompletions } } + /// + /// Begin a chat completions request and get an object that can stream response data as it becomes + /// available. + /// + /// + /// + /// + /// + /// the chat completions options for this chat completions request. + /// + /// + /// a cancellation token that can be used to cancel the initial request or ongoing streaming operation. + /// + /// + /// or is null. + /// + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual Response GetChatCompletionsStreaming2( + string deploymentOrModelName, + ChatCompletionsOptions chatCompletionsOptions, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); + Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); + + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2"); + scope.Start(); + + chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? 
null : deploymentOrModelName; + chatCompletionsOptions.InternalShouldStreamResponse = true; + + string operationPath = GetOperationPath(chatCompletionsOptions); + + RequestContent content = chatCompletionsOptions.ToRequestContent(); + RequestContext context = FromCancellationToken(cancellationToken); + + try + { + // Response value object takes IDisposable ownership of message + HttpMessage message = CreatePostRequestMessage( + deploymentOrModelName, + operationPath, + content, + context); + message.BufferResponse = false; + Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); + return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + + /// + public virtual async Task> GetChatCompletionsStreamingAsync2( + string deploymentOrModelName, + ChatCompletionsOptions chatCompletionsOptions, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); + Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); + + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2"); + scope.Start(); + + chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? 
null : deploymentOrModelName; + chatCompletionsOptions.InternalShouldStreamResponse = true; + + string operationPath = GetOperationPath(chatCompletionsOptions); + + RequestContent content = chatCompletionsOptions.ToRequestContent(); + RequestContext context = FromCancellationToken(cancellationToken); + + try + { + // Response value object takes IDisposable ownership of message + HttpMessage message = CreatePostRequestMessage( + deploymentOrModelName, + operationPath, + content, + context); + message.BufferResponse = false; + Response baseResponse = await _pipeline.ProcessMessageAsync( + message, + context, + cancellationToken).ConfigureAwait(false); + return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse); + } + catch (Exception e) + { + scope.Failed(e); + throw; + } + } + /// Return the computed embeddings for a given prompt. /// /// EnumerateChatUpdates( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (true) + { + SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); + if (sseEvent == null) + { + _baseResponse.ContentStream?.Dispose(); + break; + } + + ReadOnlyMemory name = sseEvent.Value.FieldName; + if (!name.Span.SequenceEqual("data".AsSpan())) + throw new InvalidDataException(); + + ReadOnlyMemory value = sseEvent.Value.FieldValue; + if (value.Span.SequenceEqual("[DONE]".AsSpan())) + { + _baseResponse.ContentStream?.Dispose(); + break; + } + + JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); + foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate + in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement)) + { + yield return streamingChatCompletionsUpdate; + } + } + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _baseResponseReader.Dispose(); + } + + _disposedValue = true; + } + } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs new file mode 100644 index 0000000000000..4d11ca2a15229 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using Azure.Core; + +namespace Azure.AI.OpenAI +{ + public partial class StreamingChatCompletionsUpdate + { + internal static List DeserializeStreamingChatCompletionsUpdates(JsonElement element) + { + var results = new List(); + if (element.ValueKind == JsonValueKind.Null) + { + return results; + } + string id = default; + DateTimeOffset created = default; + AzureChatExtensionsMessageContext azureExtensionsContext = null; + ContentFilterResults requestContentFilterResults = null; + foreach (JsonProperty property in element.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("created"u8)) + { + created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + // CUSTOM CODE NOTE: temporary, custom handling of forked keys for prompt filter results + if (property.NameEquals("prompt_annotations"u8) || property.NameEquals("prompt_filter_results")) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + List array = new List(); + foreach (var item in property.Value.EnumerateArray()) + { + PromptFilterResult promptFilterResult = PromptFilterResult.DeserializePromptFilterResult(item); + requestContentFilterResults = promptFilterResult.ContentFilterResults; + } + continue; + } + if (property.NameEquals("choices"u8)) + { + foreach (JsonElement choiceElement in property.Value.EnumerateArray()) + { + ChatRole? role = null; + string contentUpdate = null; + string functionName = null; + string authorName = null; + string functionArgumentsUpdate = null; + int choiceIndex = 0; + CompletionsFinishReason? 
finishReason = null; + ContentFilterResults responseContentFilterResults = null; + foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) + { + if (choiceProperty.NameEquals("index"u8)) + { + choiceIndex = choiceProperty.Value.GetInt32(); + continue; + } + if (choiceProperty.NameEquals("finish_reason"u8)) + { + finishReason = new CompletionsFinishReason(choiceProperty.Value.GetString()); + continue; + } + if (choiceProperty.NameEquals("delta"u8)) + { + foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) + { + if (deltaProperty.NameEquals("role"u8)) + { + role = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("name"u8)) + { + authorName = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("content"u8)) + { + contentUpdate = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("function_call"u8)) + { + foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) + { + if (functionProperty.NameEquals("name"u8)) + { + functionName = functionProperty.Value.GetString(); + continue; + } + if (functionProperty.NameEquals("arguments"u8)) + { + functionArgumentsUpdate = functionProperty.Value.GetString(); + } + } + } + if (deltaProperty.NameEquals("context"u8)) + { + azureExtensionsContext = AzureChatExtensionsMessageContext + .DeserializeAzureChatExtensionsMessageContext(deltaProperty.Value); + continue; + } + } + } + if (choiceProperty.NameEquals("content_filter_results"u8)) + { + if (choiceProperty.Value.EnumerateObject().Any()) + { + responseContentFilterResults = ContentFilterResults.DeserializeContentFilterResults(choiceProperty.Value); + } + continue; + } + } + if (requestContentFilterResults is not null || responseContentFilterResults is not null) + { + azureExtensionsContext ??= new AzureChatExtensionsMessageContext(); + azureExtensionsContext.RequestContentFilterResults = requestContentFilterResults; + 
azureExtensionsContext.ResponseContentFilterResults = responseContentFilterResults; + } + results.Add(new StreamingChatCompletionsUpdate( + id, + created, + choiceIndex, + role, + authorName, + contentUpdate, + finishReason, + functionName, + functionArgumentsUpdate, + azureExtensionsContext)); + } + continue; + } + } + if (results.Count == 0) + { + if (requestContentFilterResults is not null) + { + azureExtensionsContext ??= new AzureChatExtensionsMessageContext() + { + RequestContentFilterResults = requestContentFilterResults, + }; + } + results.Add(new StreamingChatCompletionsUpdate(id, created, azureExtensionsContext: azureExtensionsContext)); + } + return results; + } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs new file mode 100644 index 0000000000000..c16e0d20f2e5d --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using Azure.Core; + +namespace Azure.AI.OpenAI +{ + public partial class StreamingChatCompletionsUpdate + { + public string Id { get; } + + public DateTimeOffset Created { get; } + + public ChatRole? Role { get; } + + public string ContentUpdate { get; } + + public string FunctionName { get; } + + public string FunctionArgumentsUpdate { get; } + + public string AuthorName { get; } + + public CompletionsFinishReason? FinishReason { get; } + + public int? ChoiceIndex { get; } + + public AzureChatExtensionsMessageContext AzureExtensionsContext { get; } + + internal StreamingChatCompletionsUpdate( + string id, + DateTimeOffset created, + int? choiceIndex = null, + ChatRole? 
role = null, + string authorName = null, + string contentUpdate = null, + CompletionsFinishReason? finishReason = null, + string functionName = null, + string functionArgumentsUpdate = null, + AzureChatExtensionsMessageContext azureExtensionsContext = null) + { + Id = id; + Created = created; + ChoiceIndex = choiceIndex; + Role = role; + AuthorName = authorName; + ContentUpdate = contentUpdate; + FinishReason = finishReason; + FunctionName = functionName; + FunctionArgumentsUpdate = functionArgumentsUpdate; + AzureExtensionsContext = azureExtensionsContext; + } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs index f7eaaf696bc35..7376674572ee8 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs @@ -273,6 +273,79 @@ Response streamingResponse AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1)); } + [RecordedTest] + [TestCase(OpenAIClientServiceTarget.Azure)] + [TestCase(OpenAIClientServiceTarget.NonAzure)] + public async Task StreamingChatCompletions2(OpenAIClientServiceTarget serviceTarget) + { + OpenAIClient client = GetTestClient(serviceTarget); + string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( + serviceTarget, + OpenAIClientScenario.ChatCompletions); + var requestOptions = new ChatCompletionsOptions() + { + Messages = + { + new ChatMessage(ChatRole.System, "You are a helpful assistant."), + new ChatMessage(ChatRole.User, "Can you help me?"), + new ChatMessage(ChatRole.Assistant, "Of course! 
What do you need help with?"), + new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"), + }, + MaxTokens = 512, + }; + Response streamingResponse + = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions); + Assert.That(streamingResponse, Is.Not.Null); + using StreamingChatCompletions2 streamingChatCompletions = streamingResponse.Value; + Assert.That(streamingChatCompletions, Is.InstanceOf()); + + StringBuilder contentBuilder = new StringBuilder(); + bool gotRole = false; + bool gotRequestContentFilterResults = false; + bool gotResponseContentFilterResults = false; + + await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates()) + { + Assert.That(chatUpdate, Is.Not.Null); + + if (chatUpdate.AzureExtensionsContext?.RequestContentFilterResults is null) + { + Assert.That(chatUpdate.Id, Is.Not.Null.Or.Empty); + Assert.That(chatUpdate.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1)))); + Assert.That(chatUpdate.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7))); + } + if (chatUpdate.Role.HasValue) + { + Assert.IsFalse(gotRole); + Assert.That(chatUpdate.Role.Value, Is.EqualTo(ChatRole.Assistant)); + gotRole = true; + } + if (chatUpdate.ContentUpdate is not null) + { + contentBuilder.Append(chatUpdate.ContentUpdate); + } + if (chatUpdate.AzureExtensionsContext?.RequestContentFilterResults is not null) + { + Assert.IsFalse(gotRequestContentFilterResults); + AssertExpectedContentFilterResults(chatUpdate.AzureExtensionsContext.RequestContentFilterResults, serviceTarget); + gotRequestContentFilterResults = true; + } + if (chatUpdate.AzureExtensionsContext?.ResponseContentFilterResults is not null) + { + AssertExpectedContentFilterResults(chatUpdate.AzureExtensionsContext.ResponseContentFilterResults, serviceTarget); + gotResponseContentFilterResults = true; + } + } + + Assert.IsTrue(gotRole); + Assert.That(contentBuilder.ToString(), 
Is.Not.Null.Or.Empty); + if (serviceTarget == OpenAIClientServiceTarget.Azure) + { + Assert.IsTrue(gotRequestContentFilterResults); + Assert.IsTrue(gotResponseContentFilterResults); + } + } + [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] [TestCase(OpenAIClientServiceTarget.NonAzure)] From 6dfd14a8b19960021f448193643d7f7fca1e12b7 Mon Sep 17 00:00:00 2001 From: Travis Wilson <35748617+trrwilson@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:45:36 -0700 Subject: [PATCH 02/19] add ported functions test --- .../api/Azure.AI.OpenAI.netstandard2.0.cs | 25 +++++++ .../src/Custom/StreamingChatCompletions2.cs | 9 ++- .../tests/ChatFunctionsTests.cs | 68 +++++++++++++++++++ 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index a4b3f10dcf7df..8fd561d1c7c8d 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -114,6 +114,8 @@ public partial class AzureChatExtensionsMessageContext { public AzureChatExtensionsMessageContext() { } public System.Collections.Generic.IList Messages { get { throw null; } } + public Azure.AI.OpenAI.ContentFilterResults RequestContentFilterResults { get { throw null; } } + public Azure.AI.OpenAI.ContentFilterResults ResponseContentFilterResults { get { throw null; } } } public partial class AzureChatExtensionsOptions { @@ -479,7 +481,9 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent public virtual Azure.Response GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetChatCompletionsAsync(string deploymentOrModelName, 
Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.Response GetChatCompletionsStreaming2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw 
null; } @@ -528,6 +532,27 @@ public void Dispose() { } protected virtual void Dispose(bool disposing) { } public System.Collections.Generic.IAsyncEnumerable GetChoicesStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } } + public partial class StreamingChatCompletions2 : System.IDisposable + { + internal StreamingChatCompletions2() { } + public void Dispose() { } + protected virtual void Dispose(bool disposing) { } + public System.Collections.Generic.IAsyncEnumerable EnumerateChatUpdates([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + } + public partial class StreamingChatCompletionsUpdate + { + internal StreamingChatCompletionsUpdate() { } + public string AuthorName { get { throw null; } } + public Azure.AI.OpenAI.AzureChatExtensionsMessageContext AzureExtensionsContext { get { throw null; } } + public int? ChoiceIndex { get { throw null; } } + public string ContentUpdate { get { throw null; } } + public System.DateTimeOffset Created { get { throw null; } } + public Azure.AI.OpenAI.CompletionsFinishReason? FinishReason { get { throw null; } } + public string FunctionArgumentsUpdate { get { throw null; } } + public string FunctionName { get { throw null; } } + public string Id { get { throw null; } } + public Azure.AI.OpenAI.ChatRole? 
Role { get { throw null; } } + } public partial class StreamingChoice { internal StreamingChoice() { } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs index fb16e591e6766..967dea227f2a7 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs @@ -27,7 +27,13 @@ internal StreamingChatCompletions2(Response response) public async IAsyncEnumerable EnumerateChatUpdates( [EnumeratorCancellation] CancellationToken cancellationToken = default) { - while (true) + // Open clarification points: + // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested? + // - What should happen if the stream is held too long and it times out? + // - Should enumeration be repeatable, e.g. via locally caching a collection of the updates? + // - What if the initial enumeration was cancelled? + // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently? + while (!cancellationToken.IsCancellationRequested) { SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); if (sseEvent == null) @@ -54,6 +60,7 @@ in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sse yield return streamingChatCompletionsUpdate; } } + _baseResponse?.ContentStream?.Dispose(); } public void Dispose() diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs index 36b612ce299dd..57be0f978564a 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs @@ -135,6 +135,55 @@ Response response Assert.That(argumentsBuilder.Length, Is.GreaterThan(0)); } + [RecordedTest] + [TestCase(OpenAIClientServiceTarget.Azure)] + [TestCase(OpenAIClientServiceTarget.NonAzure)] + public async Task StreamingFunctionCallWorks2(OpenAIClientServiceTarget serviceTarget) + { + OpenAIClient client = GetTestClient(serviceTarget); + string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions); + + var requestOptions = new ChatCompletionsOptions() + { + Functions = { s_futureTemperatureFunction, s_wordOfTheDayFunction }, + Messages = + { + new ChatMessage(ChatRole.System, "You are a helpful assistant."), + new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"), + }, + MaxTokens = 512, + }; + + Response response + = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions); + Assert.That(response, Is.Not.Null); + + using StreamingChatCompletions2 streamingChatCompletions = response.Value; + + ChatRole? 
streamedRole = default; + string functionName = null; + StringBuilder argumentsBuilder = new(); + + await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates()) + { + if (chatUpdate.Role.HasValue) + { + Assert.That(streamedRole, Is.Null, "role should only appear once"); + streamedRole = chatUpdate.Role.Value; + } + if (!string.IsNullOrEmpty(chatUpdate.FunctionName)) + { + Assert.That(functionName, Is.Null, "function_name should only appear once"); + functionName = chatUpdate.FunctionName; + } + argumentsBuilder.Append(chatUpdate.FunctionArgumentsUpdate); + } + + Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant)); + Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name)); + Assert.That(argumentsBuilder.Length, Is.GreaterThan(0)); + } + private static readonly FunctionDefinition s_futureTemperatureFunction = new() { Name = "get_future_temperature", @@ -159,4 +208,23 @@ Response response }, new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }) }; + + private static readonly FunctionDefinition s_wordOfTheDayFunction = new() + { + Name = "get_word_of_the_day", + Description = "requests a featured word for a given day of the week", + Parameters = BinaryData.FromObjectAsJson(new + { + Type = "object", + Properties = new + { + DayOfWeek = new + { + Type = "string", + Description = "the name of a day of the week, like Wednesday", + }, + } + }, + new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }) + }; } From 4417b07dc6a19f238d794c63db85e05cfb19d7d5 Mon Sep 17 00:00:00 2001 From: Travis Wilson <35748617+trrwilson@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:17:48 -0700 Subject: [PATCH 03/19] remaining tests ported --- .../src/Custom/OpenAIClient.cs | 198 +++++++++--------- .../tests/AzureChatExtensionsTests.cs | 96 +++++++-- .../tests/ChatFunctionsTests.cs | 100 ++++----- .../tests/OpenAIInferenceTests.cs | 98 ++++----- 
.../tests/Samples/Sample04_StreamingChat.cs | 49 ++++- 5 files changed, 324 insertions(+), 217 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs index d55ad4462a585..6c44d50e6cc19 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs @@ -400,105 +400,105 @@ public virtual async Task> GetChatCompletionsAsync( } } - /// - /// Begin a chat completions request and get an object that can stream response data as it becomes - /// available. - /// - /// - /// - /// - /// - /// the chat completions options for this chat completions request. - /// - /// - /// a cancellation token that can be used to cancel the initial request or ongoing streaming operation. - /// - /// - /// or is null. - /// - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual Response GetChatCompletionsStreaming( - string deploymentOrModelName, - ChatCompletionsOptions chatCompletionsOptions, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); - Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); - - using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); - scope.Start(); - - chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? 
null : deploymentOrModelName; - chatCompletionsOptions.InternalShouldStreamResponse = true; - - string operationPath = GetOperationPath(chatCompletionsOptions); - - RequestContent content = chatCompletionsOptions.ToRequestContent(); - RequestContext context = FromCancellationToken(cancellationToken); - - try - { - // Response value object takes IDisposable ownership of message - HttpMessage message = CreatePostRequestMessage( - deploymentOrModelName, - operationPath, - content, - context); - message.BufferResponse = false; - Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); - } - catch (Exception e) - { - scope.Failed(e); - throw; - } - } - - /// - public virtual async Task> GetChatCompletionsStreamingAsync( - string deploymentOrModelName, - ChatCompletionsOptions chatCompletionsOptions, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); - Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); - - using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); - scope.Start(); - - chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? 
null : deploymentOrModelName; - chatCompletionsOptions.InternalShouldStreamResponse = true; - - string operationPath = GetOperationPath(chatCompletionsOptions); - - RequestContent content = chatCompletionsOptions.ToRequestContent(); - RequestContext context = FromCancellationToken(cancellationToken); - - try - { - // Response value object takes IDisposable ownership of message - HttpMessage message = CreatePostRequestMessage( - deploymentOrModelName, - operationPath, - content, - context); - message.BufferResponse = false; - Response baseResponse = await _pipeline.ProcessMessageAsync( - message, - context, - cancellationToken).ConfigureAwait(false); - return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); - } - catch (Exception e) - { - scope.Failed(e); - throw; - } - } + ///// + ///// Begin a chat completions request and get an object that can stream response data as it becomes + ///// available. + ///// + ///// + ///// + ///// + ///// + ///// the chat completions options for this chat completions request. + ///// + ///// + ///// a cancellation token that can be used to cancel the initial request or ongoing streaming operation. + ///// + ///// + ///// or is null. + ///// + ///// Service returned a non-success status code. + ///// The response returned from the service. + //public virtual Response GetChatCompletionsStreaming( + // string deploymentOrModelName, + // ChatCompletionsOptions chatCompletionsOptions, + // CancellationToken cancellationToken = default) + //{ + // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); + // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); + + // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); + // scope.Start(); + + // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? 
null : deploymentOrModelName; + // chatCompletionsOptions.InternalShouldStreamResponse = true; + + // string operationPath = GetOperationPath(chatCompletionsOptions); + + // RequestContent content = chatCompletionsOptions.ToRequestContent(); + // RequestContext context = FromCancellationToken(cancellationToken); + + // try + // { + // // Response value object takes IDisposable ownership of message + // HttpMessage message = CreatePostRequestMessage( + // deploymentOrModelName, + // operationPath, + // content, + // context); + // message.BufferResponse = false; + // Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); + // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); + // } + // catch (Exception e) + // { + // scope.Failed(e); + // throw; + // } + //} + + ///// + //public virtual async Task> GetChatCompletionsStreamingAsync( + // string deploymentOrModelName, + // ChatCompletionsOptions chatCompletionsOptions, + // CancellationToken cancellationToken = default) + //{ + // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); + // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); + + // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); + // scope.Start(); + + // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? 
null : deploymentOrModelName; + // chatCompletionsOptions.InternalShouldStreamResponse = true; + + // string operationPath = GetOperationPath(chatCompletionsOptions); + + // RequestContent content = chatCompletionsOptions.ToRequestContent(); + // RequestContext context = FromCancellationToken(cancellationToken); + + // try + // { + // // Response value object takes IDisposable ownership of message + // HttpMessage message = CreatePostRequestMessage( + // deploymentOrModelName, + // operationPath, + // content, + // context); + // message.BufferResponse = false; + // Response baseResponse = await _pipeline.ProcessMessageAsync( + // message, + // context, + // cancellationToken).ConfigureAwait(false); + // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); + // } + // catch (Exception e) + // { + // scope.Failed(e); + // throw; + // } + //} /// /// Begin a chat completions request and get an object that can stream response data as it becomes diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs index 7f56bb370064c..8f913da5a4dfd 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Azure.Core.TestFramework; @@ -88,6 +89,68 @@ public async Task BasicSearchExtensionWorks( Assert.That(context.Messages.First().Content.Contains("citations")); } + //[RecordedTest] + //[TestCase(OpenAIClientServiceTarget.Azure)] + //public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget) + //{ + // OpenAIClient client = GetTestClient(serviceTarget); + // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( + // serviceTarget, + // OpenAIClientScenario.ChatCompletions); + + // var 
requestOptions = new ChatCompletionsOptions() + // { + // Messages = + // { + // new ChatMessage(ChatRole.User, "What does PR complete mean?"), + // }, + // MaxTokens = 512, + // AzureExtensionsOptions = new() + // { + // Extensions = + // { + // new AzureChatExtensionConfiguration() + // { + // Type = "AzureCognitiveSearch", + // Parameters = BinaryData.FromObjectAsJson(new + // { + // Endpoint = "https://openaisdktestsearch.search.windows.net", + // IndexName = "openai-test-index-carbon-wiki", + // Key = GetCognitiveSearchApiKey().Key, + // }, + // new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }), + // }, + // } + // }, + // }; + + // Response response = await client.GetChatCompletionsStreamingAsync( + // deploymentOrModelName, + // requestOptions); + // Assert.That(response, Is.Not.Null); + // Assert.That(response.Value, Is.Not.Null); + + // using StreamingChatCompletions streamingChatCompletions = response.Value; + + // int choiceCount = 0; + // List messageChunks = new(); + + // await foreach (StreamingChatChoice streamingChatChoice in response.Value.GetChoicesStreaming()) + // { + // choiceCount++; + // await foreach (ChatMessage chatMessage in streamingChatChoice.GetMessageStreaming()) + // { + // messageChunks.Add(chatMessage); + // } + // } + + // Assert.That(choiceCount, Is.EqualTo(1)); + // Assert.That(messageChunks, Is.Not.Null.Or.Empty); + // Assert.That(messageChunks.Any(chunk => chunk.AzureExtensionsContext != null && chunk.AzureExtensionsContext.Messages.Any(m => m.Role == ChatRole.Tool))); + // //Assert.That(messageChunks.Any(chunk => chunk.Role == ChatRole.Assistant)); + // Assert.That(messageChunks.Any(chunk => !string.IsNullOrWhiteSpace(chunk.Content))); + //} + [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget) @@ -123,30 +186,37 @@ public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget servic }, 
}; - Response response = await client.GetChatCompletionsStreamingAsync( + Response response = await client.GetChatCompletionsStreamingAsync2( deploymentOrModelName, requestOptions); Assert.That(response, Is.Not.Null); Assert.That(response.Value, Is.Not.Null); - using StreamingChatCompletions streamingChatCompletions = response.Value; + using StreamingChatCompletions2 streamingChatCompletions = response.Value; - int choiceCount = 0; - List messageChunks = new(); + ChatRole? streamedRole = null; + IEnumerable azureContextMessages = null; + StringBuilder contentBuilder = new(); - await foreach (StreamingChatChoice streamingChatChoice in response.Value.GetChoicesStreaming()) + await foreach (StreamingChatCompletionsUpdate chatUpdate in response.Value.EnumerateChatUpdates()) { - choiceCount++; - await foreach (ChatMessage chatMessage in streamingChatChoice.GetMessageStreaming()) + if (chatUpdate.Role.HasValue) + { + Assert.That(streamedRole, Is.Null); + streamedRole = chatUpdate.Role.Value; + } + if (chatUpdate.AzureExtensionsContext?.Messages?.Count > 0) { - messageChunks.Add(chatMessage); + Assert.That(azureContextMessages, Is.Null); + azureContextMessages = chatUpdate.AzureExtensionsContext.Messages; } + contentBuilder.Append(chatUpdate.ContentUpdate); } - Assert.That(choiceCount, Is.EqualTo(1)); - Assert.That(messageChunks, Is.Not.Null.Or.Empty); - Assert.That(messageChunks.Any(chunk => chunk.AzureExtensionsContext != null && chunk.AzureExtensionsContext.Messages.Any(m => m.Role == ChatRole.Tool))); - //Assert.That(messageChunks.Any(chunk => chunk.Role == ChatRole.Assistant)); - Assert.That(messageChunks.Any(chunk => !string.IsNullOrWhiteSpace(chunk.Content))); + Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant)); + Assert.That(contentBuilder.ToString(), Is.Not.Null.Or.Empty); + Assert.That(azureContextMessages, Is.Not.Null.Or.Empty); + Assert.That(azureContextMessages.Any(contextMessage => contextMessage.Role == ChatRole.Tool)); + 
Assert.That(azureContextMessages.Any(contextMessage => !string.IsNullOrEmpty(contextMessage.Content))); } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs index 57be0f978564a..c7d3289706775 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs @@ -84,56 +84,56 @@ public async Task SimpleFunctionCallWorks(OpenAIClientServiceTarget serviceTarge Assert.That(followupResponse.Value.Choices[0].Message.Content, Is.Not.Null.Or.Empty); } - [RecordedTest] - [TestCase(OpenAIClientServiceTarget.Azure)] - [TestCase(OpenAIClientServiceTarget.NonAzure)] - public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget) - { - OpenAIClient client = GetTestClient(serviceTarget); - string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions); - - var requestOptions = new ChatCompletionsOptions() - { - Functions = { s_futureTemperatureFunction }, - Messages = - { - new ChatMessage(ChatRole.System, "You are a helpful assistant."), - new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"), - }, - MaxTokens = 512, - }; - - Response response - = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); - Assert.That(response, Is.Not.Null); - - using StreamingChatCompletions streamingChatCompletions = response.Value; - - ChatRole streamedRole = default; - string functionName = default; - StringBuilder argumentsBuilder = new(); - - await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) - { - await foreach (ChatMessage message in choice.GetMessageStreaming()) - { - if (message.Role != default) - { - streamedRole = message.Role; - } - if (message.FunctionCall?.Name != null) - { - Assert.That(functionName, Is.Null.Or.Empty); - functionName = message.FunctionCall.Name; - } 
- argumentsBuilder.Append(message.FunctionCall?.Arguments ?? string.Empty); - } - } - - Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant)); - Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name)); - Assert.That(argumentsBuilder.Length, Is.GreaterThan(0)); - } + //[RecordedTest] + //[TestCase(OpenAIClientServiceTarget.Azure)] + //[TestCase(OpenAIClientServiceTarget.NonAzure)] + //public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget) + //{ + // OpenAIClient client = GetTestClient(serviceTarget); + // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions); + + // var requestOptions = new ChatCompletionsOptions() + // { + // Functions = { s_futureTemperatureFunction }, + // Messages = + // { + // new ChatMessage(ChatRole.System, "You are a helpful assistant."), + // new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"), + // }, + // MaxTokens = 512, + // }; + + // Response response + // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); + // Assert.That(response, Is.Not.Null); + + // using StreamingChatCompletions streamingChatCompletions = response.Value; + + // ChatRole streamedRole = default; + // string functionName = default; + // StringBuilder argumentsBuilder = new(); + + // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) + // { + // await foreach (ChatMessage message in choice.GetMessageStreaming()) + // { + // if (message.Role != default) + // { + // streamedRole = message.Role; + // } + // if (message.FunctionCall?.Name != null) + // { + // Assert.That(functionName, Is.Null.Or.Empty); + // functionName = message.FunctionCall.Name; + // } + // argumentsBuilder.Append(message.FunctionCall?.Arguments ?? 
string.Empty); + // } + // } + + // Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant)); + // Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name)); + // Assert.That(argumentsBuilder.Length, Is.GreaterThan(0)); + //} [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs index 7376674572ee8..5b8ba0ba9f42e 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs @@ -223,55 +223,55 @@ string deploymentOrModelName AssertExpectedContentFilterResults(firstChoice.ContentFilterResults, serviceTarget); } - [RecordedTest] - [TestCase(OpenAIClientServiceTarget.Azure)] - [TestCase(OpenAIClientServiceTarget.NonAzure)] - public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget) - { - OpenAIClient client = GetTestClient(serviceTarget); - string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( - serviceTarget, - OpenAIClientScenario.ChatCompletions); - var requestOptions = new ChatCompletionsOptions() - { - Messages = - { - new ChatMessage(ChatRole.System, "You are a helpful assistant."), - new ChatMessage(ChatRole.User, "Can you help me?"), - new ChatMessage(ChatRole.Assistant, "Of course! 
What do you need help with?"), - new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"), - }, - MaxTokens = 512, - }; - Response streamingResponse - = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); - Assert.That(streamingResponse, Is.Not.Null); - using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value; - Assert.That(streamingChatCompletions, Is.InstanceOf()); - - int totalMessages = 0; - - await foreach (StreamingChatChoice streamingChoice in streamingChatCompletions.GetChoicesStreaming()) - { - Assert.That(streamingChoice, Is.Not.Null); - await foreach (ChatMessage streamingMessage in streamingChoice.GetMessageStreaming()) - { - Assert.That(streamingMessage.Role, Is.EqualTo(ChatRole.Assistant)); - totalMessages++; - } - AssertExpectedContentFilterResults(streamingChoice.ContentFilterResults, serviceTarget); - } - - Assert.That(totalMessages, Is.GreaterThan(1)); - - // Note: these top-level values *are likely not yet populated* until *after* at least one streaming - // choice has arrived. - Assert.That(streamingResponse.GetRawResponse(), Is.Not.Null.Or.Empty); - Assert.That(streamingChatCompletions.Id, Is.Not.Null.Or.Empty); - Assert.That(streamingChatCompletions.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1)))); - Assert.That(streamingChatCompletions.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7))); - AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 
1)); - } + //[RecordedTest] + //[TestCase(OpenAIClientServiceTarget.Azure)] + //[TestCase(OpenAIClientServiceTarget.NonAzure)] + //public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget) + //{ + // OpenAIClient client = GetTestClient(serviceTarget); + // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( + // serviceTarget, + // OpenAIClientScenario.ChatCompletions); + // var requestOptions = new ChatCompletionsOptions() + // { + // Messages = + // { + // new ChatMessage(ChatRole.System, "You are a helpful assistant."), + // new ChatMessage(ChatRole.User, "Can you help me?"), + // new ChatMessage(ChatRole.Assistant, "Of course! What do you need help with?"), + // new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"), + // }, + // MaxTokens = 512, + // }; + // Response streamingResponse + // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); + // Assert.That(streamingResponse, Is.Not.Null); + // using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value; + // Assert.That(streamingChatCompletions, Is.InstanceOf()); + + // int totalMessages = 0; + + // await foreach (StreamingChatChoice streamingChoice in streamingChatCompletions.GetChoicesStreaming()) + // { + // Assert.That(streamingChoice, Is.Not.Null); + // await foreach (ChatMessage streamingMessage in streamingChoice.GetMessageStreaming()) + // { + // Assert.That(streamingMessage.Role, Is.EqualTo(ChatRole.Assistant)); + // totalMessages++; + // } + // AssertExpectedContentFilterResults(streamingChoice.ContentFilterResults, serviceTarget); + // } + + // Assert.That(totalMessages, Is.GreaterThan(1)); + + // // Note: these top-level values *are likely not yet populated* until *after* at least one streaming + // // choice has arrived. 
+ // Assert.That(streamingResponse.GetRawResponse(), Is.Not.Null.Or.Empty); + // Assert.That(streamingChatCompletions.Id, Is.Not.Null.Or.Empty); + // Assert.That(streamingChatCompletions.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1)))); + // Assert.That(streamingChatCompletions.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7))); + // AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1)); + //} [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs index 21a25241f13d2..bd0ee5d1db4e1 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs @@ -11,6 +11,39 @@ namespace Azure.AI.OpenAI.Tests.Samples { public partial class StreamingChat { + //[Test] + //[Ignore("Only verifying that the sample builds")] + //public async Task StreamingChatWithNonAzureOpenAI() + //{ + // #region Snippet:StreamChatMessages + // string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com"; + // var client = new OpenAIClient(nonAzureOpenAIApiKey, new OpenAIClientOptions()); + // var chatCompletionsOptions = new ChatCompletionsOptions() + // { + // Messages = + // { + // new ChatMessage(ChatRole.System, "You are a helpful assistant. You will talk like a pirate."), + // new ChatMessage(ChatRole.User, "Can you help me?"), + // new ChatMessage(ChatRole.Assistant, "Arrrr! Of course, me hearty! 
What can I do for ye?"), + // new ChatMessage(ChatRole.User, "What's the best way to train a parrot?"), + // } + // }; + + // Response response = await client.GetChatCompletionsStreamingAsync( + // deploymentOrModelName: "gpt-3.5-turbo", + // chatCompletionsOptions); + // using StreamingChatCompletions streamingChatCompletions = response.Value; + + // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) + // { + // await foreach (ChatMessage message in choice.GetMessageStreaming()) + // { + // Console.Write(message.Content); + // } + // Console.WriteLine(); + // } + // #endregion + //} [Test] [Ignore("Only verifying that the sample builds")] public async Task StreamingChatWithNonAzureOpenAI() @@ -29,18 +62,22 @@ public async Task StreamingChatWithNonAzureOpenAI() } }; - Response response = await client.GetChatCompletionsStreamingAsync( + Response response = await client.GetChatCompletionsStreamingAsync2( deploymentOrModelName: "gpt-3.5-turbo", chatCompletionsOptions); - using StreamingChatCompletions streamingChatCompletions = response.Value; + using StreamingChatCompletions2 streamingChatCompletions = response.Value; - await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) + await foreach (StreamingChatCompletionsUpdate chatUpdate + in streamingChatCompletions.EnumerateChatUpdates()) { - await foreach (ChatMessage message in choice.GetMessageStreaming()) + if (chatUpdate.Role.HasValue) + { + Console.Write($"{chatUpdate.Role.Value.ToString().ToUpperInvariant()}: "); + } + if (!string.IsNullOrEmpty(chatUpdate.ContentUpdate)) { - Console.Write(message.Content); + Console.Write(chatUpdate.ContentUpdate); } - Console.WriteLine(); } #endregion } From ac477eee70d2216a38790c2051883ce60efe11e2 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Wed, 18 Oct 2023 16:10:47 -0700 Subject: [PATCH 04/19] completions for consistency --- sdk/openai/Azure.AI.OpenAI/README.md | 12 +- 
.../api/Azure.AI.OpenAI.netstandard2.0.cs | 41 +--- .../src/Custom/AzureOpenAIModelFactory.cs | 60 +++--- .../src/Custom/OpenAIClient.cs | 114 +---------- .../src/Custom/StreamingChatChoice.cs | 148 -------------- .../src/Custom/StreamingChatCompletions.cs | 184 ++++-------------- .../src/Custom/StreamingChatCompletions2.cs | 85 -------- .../src/Custom/StreamingChoice.cs | 150 -------------- .../src/Custom/StreamingCompletions.cs | 184 ++++-------------- .../tests/AzureChatExtensionsTests.cs | 4 +- .../tests/ChatFunctionsTests.cs | 6 +- .../tests/OpenAIInferenceTests.cs | 96 +++++---- .../tests/Samples/Sample04_StreamingChat.cs | 37 +--- 13 files changed, 183 insertions(+), 938 deletions(-) delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md index e7f4acb1bfe11..cd8a7d002be92 100644 --- a/sdk/openai/Azure.AI.OpenAI/README.md +++ b/sdk/openai/Azure.AI.OpenAI/README.md @@ -211,13 +211,17 @@ Response response = await client.GetChatCompletionsStr chatCompletionsOptions); using StreamingChatCompletions streamingChatCompletions = response.Value; -await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) +await foreach (StreamingChatCompletionsUpdate chatUpdate + in streamingChatCompletions.EnumerateChatUpdates()) { - await foreach (ChatMessage message in choice.GetMessageStreaming()) + if (chatUpdate.Role.HasValue) { - Console.Write(message.Content); + Console.Write($"{chatUpdate.Role.Value.ToString().ToUpperInvariant()}: "); + } + if (!string.IsNullOrEmpty(chatUpdate.ContentUpdate)) + { + Console.Write(chatUpdate.ContentUpdate); } - Console.WriteLine(); } ``` diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs 
b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index 8fd561d1c7c8d..7a2dfac1a77bc 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -206,10 +206,9 @@ public static partial class AzureOpenAIModelFactory public static Azure.AI.OpenAI.ImageGenerations ImageGenerations(System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable data = null) { throw null; } public static Azure.AI.OpenAI.ImageLocation ImageLocation(System.Uri url = null) { throw null; } public static Azure.AI.OpenAI.PromptFilterResult PromptFilterResult(int promptIndex = 0, Azure.AI.OpenAI.ContentFilterResults contentFilterResults = null) { throw null; } - public static Azure.AI.OpenAI.StreamingChatChoice StreamingChatChoice(Azure.AI.OpenAI.ChatChoice originalBaseChoice = null) { throw null; } - public static Azure.AI.OpenAI.StreamingChatCompletions StreamingChatCompletions(Azure.AI.OpenAI.ChatCompletions baseChatCompletions = null, System.Collections.Generic.List streamingChatChoices = null) { throw null; } - public static Azure.AI.OpenAI.StreamingChoice StreamingChoice(Azure.AI.OpenAI.Choice originalBaseChoice = null) { throw null; } - public static Azure.AI.OpenAI.StreamingCompletions StreamingCompletions(Azure.AI.OpenAI.Completions baseCompletions = null, System.Collections.Generic.List streamingChoices = null) { throw null; } + public static Azure.AI.OpenAI.StreamingChatCompletions StreamingChatCompletions(System.Collections.Generic.IEnumerable updates) { throw null; } + public static Azure.AI.OpenAI.StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(string id, System.DateTimeOffset created, int? choiceIndex = default(int?), Azure.AI.OpenAI.ChatRole? role = default(Azure.AI.OpenAI.ChatRole?), string authorName = null, string contentUpdate = null, Azure.AI.OpenAI.CompletionsFinishReason? 
finishReason = default(Azure.AI.OpenAI.CompletionsFinishReason?), string functionName = null, string functionArgumentsUpdate = null, Azure.AI.OpenAI.AzureChatExtensionsMessageContext azureExtensionsContext = null) { throw null; } + public static Azure.AI.OpenAI.StreamingCompletions StreamingCompletions(System.Collections.Generic.IEnumerable completions) { throw null; } } public partial class ChatChoice { @@ -481,9 +480,7 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent public virtual Azure.Response GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetChatCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.Response GetChatCompletionsStreaming2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync2(string deploymentOrModelName, 
Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -514,27 +511,9 @@ internal PromptFilterResult() { } public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } } public int PromptIndex { get { throw null; } } } - public partial class StreamingChatChoice - { - internal StreamingChatChoice() { } - public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } } - public Azure.AI.OpenAI.CompletionsFinishReason? FinishReason { get { throw null; } } - public int? 
Index { get { throw null; } } - public System.Collections.Generic.IAsyncEnumerable GetMessageStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - } public partial class StreamingChatCompletions : System.IDisposable { internal StreamingChatCompletions() { } - public System.DateTimeOffset Created { get { throw null; } } - public string Id { get { throw null; } } - public System.Collections.Generic.IReadOnlyList PromptFilterResults { get { throw null; } } - public void Dispose() { } - protected virtual void Dispose(bool disposing) { } - public System.Collections.Generic.IAsyncEnumerable GetChoicesStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - } - public partial class StreamingChatCompletions2 : System.IDisposable - { - internal StreamingChatCompletions2() { } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public System.Collections.Generic.IAsyncEnumerable EnumerateChatUpdates([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -553,24 +532,12 @@ internal StreamingChatCompletionsUpdate() { } public string Id { get { throw null; } } public Azure.AI.OpenAI.ChatRole? Role { get { throw null; } } } - public partial class StreamingChoice - { - internal StreamingChoice() { } - public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } } - public Azure.AI.OpenAI.CompletionsFinishReason? FinishReason { get { throw null; } } - public int? 
Index { get { throw null; } } - public Azure.AI.OpenAI.CompletionsLogProbabilityModel LogProbabilityModel { get { throw null; } } - public System.Collections.Generic.IAsyncEnumerable GetTextStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - } public partial class StreamingCompletions : System.IDisposable { internal StreamingCompletions() { } - public System.DateTimeOffset Created { get { throw null; } } - public string Id { get { throw null; } } - public System.Collections.Generic.IReadOnlyList PromptFilterResults { get { throw null; } } public void Dispose() { } protected virtual void Dispose(bool disposing) { } - public System.Collections.Generic.IAsyncEnumerable GetChoicesStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Collections.Generic.IAsyncEnumerable EnumerateCompletions([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } } } namespace Microsoft.Extensions.Azure diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs index a879727738372..cc90ea19685df 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs @@ -30,50 +30,52 @@ public static ChatChoice ChatChoice( return new ChatChoice(message, index, finishReason, deltaMessage, contentFilterResults); } - /// - /// Initializes a new instance of StreamingChoice for tests and mocking. 
- /// - /// An underlying Choice for this streaming representation - /// A new instance of StreamingChoice - public static StreamingChoice StreamingChoice(Choice originalBaseChoice = null) - { - return new StreamingChoice(originalBaseChoice); - } - /// /// Initializes a new instance of StreamingCompletions for tests and mocking. /// - /// The non-streaming completions to base this streaming representation on - /// The streaming choices associated with this streaming completions - /// A new instance of StreamingCompletions - public static StreamingCompletions StreamingCompletions( - Completions baseCompletions = null, - List streamingChoices = null) + /// A static set of Completions instances to use for tests and mocking. + /// A new instance of StreamingCompletions. + public static StreamingCompletions StreamingCompletions(IEnumerable completions) { - return new StreamingCompletions(baseCompletions, streamingChoices); + return new StreamingCompletions(completions); } - /// - /// Initializes a new instance of StreamingChatChoice for tests and mocking. - /// - /// An underlying ChatChoice for this streaming representation - /// A new instance of StreamingChatChoice - public static StreamingChatChoice StreamingChatChoice(ChatChoice originalBaseChoice = null) + public static StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate( + string id, + DateTimeOffset created, + int? choiceIndex = null, + ChatRole? role = null, + string authorName = null, + string contentUpdate = null, + CompletionsFinishReason? 
finishReason = null, + string functionName = null, + string functionArgumentsUpdate = null, + AzureChatExtensionsMessageContext azureExtensionsContext = null) { - return new StreamingChatChoice(originalBaseChoice); + return new StreamingChatCompletionsUpdate( + id, + created, + choiceIndex, + role, + authorName, + contentUpdate, + finishReason, + functionName, + functionArgumentsUpdate, + azureExtensionsContext); } /// /// Initializes a new instance of StreamingChatCompletions for tests and mocking. /// - /// The non-streaming completions to base this streaming representation on - /// The streaming choices associated with this streaming chat completions + /// + /// A static collection of StreamingChatCompletionUpdates to use for tests and mocking. + /// /// A new instance of StreamingChatCompletions public static StreamingChatCompletions StreamingChatCompletions( - ChatCompletions baseChatCompletions = null, - List streamingChatChoices = null) + IEnumerable updates) { - return new StreamingChatCompletions(baseChatCompletions, streamingChatChoices); + return new StreamingChatCompletions(updates); } /// Initializes a new instance of AudioTranscription. diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs index 6c44d50e6cc19..e46691d42ae88 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs @@ -400,106 +400,6 @@ public virtual async Task> GetChatCompletionsAsync( } } - ///// - ///// Begin a chat completions request and get an object that can stream response data as it becomes - ///// available. - ///// - ///// - ///// - ///// - ///// - ///// the chat completions options for this chat completions request. - ///// - ///// - ///// a cancellation token that can be used to cancel the initial request or ongoing streaming operation. - ///// - ///// - ///// or is null. - ///// - ///// Service returned a non-success status code. 
- ///// The response returned from the service. - //public virtual Response GetChatCompletionsStreaming( - // string deploymentOrModelName, - // ChatCompletionsOptions chatCompletionsOptions, - // CancellationToken cancellationToken = default) - //{ - // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); - // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); - - // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); - // scope.Start(); - - // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName; - // chatCompletionsOptions.InternalShouldStreamResponse = true; - - // string operationPath = GetOperationPath(chatCompletionsOptions); - - // RequestContent content = chatCompletionsOptions.ToRequestContent(); - // RequestContext context = FromCancellationToken(cancellationToken); - - // try - // { - // // Response value object takes IDisposable ownership of message - // HttpMessage message = CreatePostRequestMessage( - // deploymentOrModelName, - // operationPath, - // content, - // context); - // message.BufferResponse = false; - // Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); - // } - // catch (Exception e) - // { - // scope.Failed(e); - // throw; - // } - //} - - ///// - //public virtual async Task> GetChatCompletionsStreamingAsync( - // string deploymentOrModelName, - // ChatCompletionsOptions chatCompletionsOptions, - // CancellationToken cancellationToken = default) - //{ - // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); - // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); - - // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); - // scope.Start(); - - 
// chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName; - // chatCompletionsOptions.InternalShouldStreamResponse = true; - - // string operationPath = GetOperationPath(chatCompletionsOptions); - - // RequestContent content = chatCompletionsOptions.ToRequestContent(); - // RequestContext context = FromCancellationToken(cancellationToken); - - // try - // { - // // Response value object takes IDisposable ownership of message - // HttpMessage message = CreatePostRequestMessage( - // deploymentOrModelName, - // operationPath, - // content, - // context); - // message.BufferResponse = false; - // Response baseResponse = await _pipeline.ProcessMessageAsync( - // message, - // context, - // cancellationToken).ConfigureAwait(false); - // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); - // } - // catch (Exception e) - // { - // scope.Failed(e); - // throw; - // } - //} - /// /// Begin a chat completions request and get an object that can stream response data as it becomes /// available. @@ -520,7 +420,7 @@ public virtual async Task> GetChatCompletionsAsync( /// /// Service returned a non-success status code. /// The response returned from the service. 
- public virtual Response GetChatCompletionsStreaming2( + public virtual Response GetChatCompletionsStreaming( string deploymentOrModelName, ChatCompletionsOptions chatCompletionsOptions, CancellationToken cancellationToken = default) @@ -528,7 +428,7 @@ public virtual Response GetChatCompletionsStreaming2( Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); - using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2"); + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); scope.Start(); chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName; @@ -549,7 +449,7 @@ public virtual Response GetChatCompletionsStreaming2( context); message.BufferResponse = false; Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse); + return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); } catch (Exception e) { @@ -558,8 +458,8 @@ public virtual Response GetChatCompletionsStreaming2( } } - /// - public virtual async Task> GetChatCompletionsStreamingAsync2( + /// + public virtual async Task> GetChatCompletionsStreamingAsync( string deploymentOrModelName, ChatCompletionsOptions chatCompletionsOptions, CancellationToken cancellationToken = default) @@ -567,7 +467,7 @@ public virtual async Task> GetChatCompletion Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName)); Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions)); - using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2"); + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming"); 
scope.Start(); chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName; @@ -591,7 +491,7 @@ public virtual async Task> GetChatCompletion message, context, cancellationToken).ConfigureAwait(false); - return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse); + return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); } catch (Exception e) { diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs deleted file mode 100644 index 0a631b29116fc..0000000000000 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading; - -namespace Azure.AI.OpenAI -{ - public class StreamingChatChoice - { - private readonly IList _baseChoices; - private readonly object _baseChoicesLock = new object(); - private readonly AsyncAutoResetEvent _updateAvailableEvent; - - /// - /// Gets the response index associated with this StreamingChoice as represented relative to other Choices - /// in the same Completions response. - /// - /// - /// Indices may be used to correlate individual Choices within a Completions result to their configured - /// prompts as provided in the request. As an example, if two Choices are requested for each of four prompts, - /// the Choices with indices 0 and 1 will correspond to the first prompt, 2 and 3 to the second, and so on. - /// - public int? Index => GetLocked(() => _baseChoices.Last().Index); - - /// - /// Gets a value representing why response generation ended when producing this StreamingChoice. 
- /// - /// - /// Normal termination typically provides "stop" and encountering token limits in a request typically - /// provides "length." If no value is present, this StreamingChoice is still in progress. - /// - public CompletionsFinishReason? FinishReason => GetLocked(() => _baseChoices.Last().FinishReason); - - /// - /// Information about the content filtering category (hate, sexual, violence, self_harm), if it - /// has been detected, as well as the severity level (very_low, low, medium, high-scale that - /// determines the intensity and risk level of harmful content) and if it has been filtered or not. - /// - public ContentFilterResults ContentFilterResults - => GetLocked(() => - { - return _baseChoices - .LastOrDefault(baseChoice => baseChoice.ContentFilterResults != null && baseChoice.ContentFilterResults.Hate != null) - ?.ContentFilterResults; - }); - - internal ChatMessage StreamingDeltaMessage { get; set; } - - private bool _isFinishedStreaming { get; set; } = false; - - private Exception _pumpException { get; set; } - - internal StreamingChatChoice(ChatChoice originalBaseChoice) - { - _baseChoices = new List() { originalBaseChoice }; - _updateAvailableEvent = new AsyncAutoResetEvent(); - } - - internal void UpdateFromEventStreamChatChoice(ChatChoice streamingChatChoice) - { - lock (_baseChoicesLock) - { - _baseChoices.Add(streamingChatChoice); - } - if (streamingChatChoice.FinishReason != null) - { - EnsureFinishStreaming(); - } - _updateAvailableEvent.Set(); - } - - /// - /// Gets an asynchronous enumeration that will retrieve the Completion text associated with this Choice as it - /// becomes available. Each string will provide one or more tokens, including whitespace, and a full - /// concatenation of all enumerated strings is functionally equivalent to the single Text property on a - /// non-streaming Completions Choice. 
- /// - /// - /// - public async IAsyncEnumerable GetMessageStreaming( - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - bool isFinalIndex = false; - for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++) - { - bool doneWaiting = false; - while (!doneWaiting) - { - lock (_baseChoicesLock) - { - ChatChoice mostRecentChoice = _baseChoices.Last(); - - doneWaiting = _isFinishedStreaming || i < _baseChoices.Count; - isFinalIndex = _isFinishedStreaming && i >= _baseChoices.Count - 1; - } - - if (!doneWaiting) - { - await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false); - } - } - - if (_pumpException != null) - { - throw _pumpException; - } - - ChatMessage message = null; - - lock (_baseChoicesLock) - { - if (i < _baseChoices.Count) - { - message = _baseChoices[i].InternalStreamingDeltaMessage; - message.Role = _baseChoices.First().InternalStreamingDeltaMessage.Role; - } - } - - if (message != null) - { - yield return message; - } - } - } - - internal void EnsureFinishStreaming(Exception pumpException = null) - { - if (!_isFinishedStreaming) - { - _isFinishedStreaming = true; - _pumpException = pumpException; - _updateAvailableEvent.Set(); - } - } - - private T GetLocked(Func func) - { - lock (_baseChoicesLock) - { - return func.Invoke(); - } - } - } -} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs index 71dd993e240d8..fe868d29f6e60 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs @@ -8,7 +8,6 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; -using System.Threading.Tasks; using Azure.Core.Sse; namespace Azure.AI.OpenAI @@ -17,169 +16,66 @@ public class StreamingChatCompletions : IDisposable { private readonly Response _baseResponse; private readonly 
SseReader _baseResponseReader; - private readonly IList _baseChatCompletions; - private readonly object _baseCompletionsLock = new(); - private readonly IList _streamingChatChoices; - private readonly object _streamingChoicesLock = new(); - private readonly AsyncAutoResetEvent _updateAvailableEvent; - private bool _streamingTaskComplete; + private List _cachedUpdates; private bool _disposedValue; - private Exception _pumpException; - - /// - /// Gets the earliest Completion creation timestamp associated with this streamed response. - /// - public DateTimeOffset Created => GetLocked(() => _baseChatCompletions.Last().Created); - - /// - /// Gets the unique identifier associated with this streaming Completions response. - /// - public string Id => GetLocked(() => _baseChatCompletions.Last().Id); - - /// - /// Content filtering results for zero or more prompts in the request. In a streaming request, - /// results for different prompts may arrive at different times or in different orders. - /// - public IReadOnlyList PromptFilterResults - => GetLocked(() => - { - return _baseChatCompletions.Where(singleBaseChatCompletion => singleBaseChatCompletion.PromptFilterResults != null) - .SelectMany(singleBaseChatCompletion => singleBaseChatCompletion.PromptFilterResults) - .OrderBy(singleBaseCompletion => singleBaseCompletion.PromptIndex) - .ToList(); - }); internal StreamingChatCompletions(Response response) { _baseResponse = response; _baseResponseReader = new SseReader(response.ContentStream); - _updateAvailableEvent = new AsyncAutoResetEvent(); - _baseChatCompletions = new List(); - _streamingChatChoices = new List(); - _streamingTaskComplete = false; - _ = Task.Run(async () => - { - try - { - while (true) - { - SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent == null) - { - _baseResponse.ContentStream?.Dispose(); - break; - } - - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - throw new InvalidDataException(); - - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - _baseResponse.ContentStream?.Dispose(); - break; - } - - JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); - ChatCompletions chatCompletionsFromSse = ChatCompletions.DeserializeChatCompletions(sseMessageJson.RootElement); - - lock (_baseCompletionsLock) - { - _baseChatCompletions.Add(chatCompletionsFromSse); - } + _cachedUpdates = new(); + } - foreach (ChatChoice chatChoiceFromSse in chatCompletionsFromSse.Choices) - { - lock (_streamingChoicesLock) - { - StreamingChatChoice existingStreamingChoice = _streamingChatChoices - .FirstOrDefault(chatChoice => chatChoice.Index == chatChoiceFromSse.Index); - if (existingStreamingChoice == null) - { - StreamingChatChoice newStreamingChatChoice = new(chatChoiceFromSse); - _streamingChatChoices.Add(newStreamingChatChoice); - _updateAvailableEvent.Set(); - } - else - { - existingStreamingChoice.UpdateFromEventStreamChatChoice(chatChoiceFromSse); - } - } - } - } - } - catch (Exception pumpException) - { - _pumpException = pumpException; - } - finally - { - lock (_streamingChoicesLock) - { - // If anything went wrong and a StreamingChatChoice didn't naturally determine it was complete - // based on a non-null finish reason, ensure that nothing's left incomplete (and potentially - // hanging!) now. 
- foreach (StreamingChatChoice streamingChatChoice in _streamingChatChoices) - { - streamingChatChoice.EnsureFinishStreaming(_pumpException); - } - } - _streamingTaskComplete = true; - _updateAvailableEvent.Set(); - } - }); + internal StreamingChatCompletions(IEnumerable updates) + { + _cachedUpdates.AddRange(updates); } - public async IAsyncEnumerable GetChoicesStreaming( + public async IAsyncEnumerable EnumerateChatUpdates( [EnumeratorCancellation] CancellationToken cancellationToken = default) { - bool isFinalIndex = false; - for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++) + if (_cachedUpdates.Any()) { - bool doneWaiting = false; - while (!doneWaiting) + foreach (StreamingChatCompletionsUpdate update in _cachedUpdates) { - lock (_streamingChoicesLock) - { - doneWaiting = _streamingTaskComplete || i < _streamingChatChoices.Count; - isFinalIndex = _streamingTaskComplete && i >= _streamingChatChoices.Count - 1; - } - - if (!doneWaiting) - { - await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false); - } + yield return update; } + yield break; + } - if (_pumpException != null) + // Open clarification points: + // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested? + // - What should happen if the stream is held too long and it times out? + // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently? + while (!cancellationToken.IsCancellationRequested) + { + SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); + if (sseEvent == null) { - throw _pumpException; + _baseResponse.ContentStream?.Dispose(); + break; } - StreamingChatChoice newChatChoice = null; - lock (_streamingChoicesLock) + ReadOnlyMemory name = sseEvent.Value.FieldName; + if (!name.Span.SequenceEqual("data".AsSpan())) + throw new InvalidDataException(); + + ReadOnlyMemory value = sseEvent.Value.FieldValue; + if (value.Span.SequenceEqual("[DONE]".AsSpan())) { - if (i < _streamingChatChoices.Count) - { - newChatChoice = _streamingChatChoices[i]; - } + _baseResponse.ContentStream?.Dispose(); + break; } - if (newChatChoice != null) + JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); + foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate + in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement)) { - yield return newChatChoice; + _cachedUpdates.Add(streamingChatCompletionsUpdate); + yield return streamingChatCompletionsUpdate; } } - } - - internal StreamingChatCompletions( - ChatCompletions baseChatCompletions = null, - List streamingChatChoices = null) - { - _baseChatCompletions.Add(baseChatCompletions); - _streamingChatChoices = streamingChatChoices; - _streamingTaskComplete = true; + _baseResponse?.ContentStream?.Dispose(); } public void Dispose() @@ -200,13 +96,5 @@ protected virtual void Dispose(bool disposing) _disposedValue = true; } } - - private T GetLocked(Func func) - { - lock (_baseCompletionsLock) - { - return func.Invoke(); - } - } } } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs deleted file mode 100644 index 967dea227f2a7..0000000000000 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Threading; -using Azure.Core.Sse; - -namespace Azure.AI.OpenAI -{ - public class StreamingChatCompletions2 : IDisposable - { - private readonly Response _baseResponse; - private readonly SseReader _baseResponseReader; - private bool _disposedValue; - - internal StreamingChatCompletions2(Response response) - { - _baseResponse = response; - _baseResponseReader = new SseReader(response.ContentStream); - } - - public async IAsyncEnumerable EnumerateChatUpdates( - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - // Open clarification points: - // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested? - // - What should happen if the stream is held too long and it times out? - // - Should enumeration be repeatable, e.g. via locally caching a collection of the updates? - // - What if the initial enumeration was cancelled? - // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently? - while (!cancellationToken.IsCancellationRequested) - { - SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent == null) - { - _baseResponse.ContentStream?.Dispose(); - break; - } - - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - throw new InvalidDataException(); - - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - _baseResponse.ContentStream?.Dispose(); - break; - } - - JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); - foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate - in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement)) - { - yield return streamingChatCompletionsUpdate; - } - } - _baseResponse?.ContentStream?.Dispose(); - } - - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _baseResponseReader.Dispose(); - } - - _disposedValue = true; - } - } - } -} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs deleted file mode 100644 index 51fbde8ed9a60..0000000000000 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading; - -namespace Azure.AI.OpenAI -{ - public class StreamingChoice - { - private readonly IList _baseChoices; - private readonly object _baseChoicesLock = new object(); - private readonly AsyncAutoResetEvent _updateAvailableEvent; - - /// - /// Gets the response index associated with this StreamingChoice as represented relative to other Choices - /// in the same Completions response. - /// - /// - /// Indices may be used to correlate individual Choices within a Completions result to their configured - /// prompts as provided in the request. As an example, if two Choices are requested for each of four prompts, - /// the Choices with indices 0 and 1 will correspond to the first prompt, 2 and 3 to the second, and so on. - /// - public int? Index => GetLocked(() => _baseChoices.Last().Index); - - /// - /// Gets a value representing why response generation ended when producing this StreamingChoice. - /// - /// - /// Normal termination typically provides "stop" and encountering token limits in a request typically - /// provides "length." If no value is present, this StreamingChoice is still in progress. - /// - public CompletionsFinishReason? FinishReason => GetLocked(() => _baseChoices.Last().FinishReason); - - /// - /// Information about the content filtering category (hate, sexual, violence, self_harm), if it - /// has been detected, as well as the severity level (very_low, low, medium, high-scale that - /// determines the intensity and risk level of harmful content) and if it has been filtered or not. 
- /// - public ContentFilterResults ContentFilterResults - => GetLocked(() => - { - return _baseChoices - .LastOrDefault(baseChoice => baseChoice.ContentFilterResults != null && baseChoice.ContentFilterResults.Hate != null) - ?.ContentFilterResults; - }); - - private bool _isFinishedStreaming { get; set; } = false; - - private Exception _pumpException { get; set; } - - /// - /// Gets the log probabilities associated with tokens in this Choice. - /// - public CompletionsLogProbabilityModel LogProbabilityModel - => GetLocked(() => _baseChoices.Last().LogProbabilityModel); - - internal StreamingChoice(Choice originalBaseChoice) - { - _baseChoices = new List() { originalBaseChoice }; - _updateAvailableEvent = new AsyncAutoResetEvent(); - } - - internal void UpdateFromEventStreamChoice(Choice streamingChoice) - { - lock (_baseChoicesLock) - { - _baseChoices.Add(streamingChoice); - } - if (streamingChoice.FinishReason != null) - { - EnsureFinishStreaming(); - } - _updateAvailableEvent.Set(); - } - - /// - /// Gets an asynchronous enumeration that will retrieve the Completion text associated with this Choice as it - /// becomes available. Each string will provide one or more tokens, including whitespace, and a full - /// concatenation of all enumerated strings is functionally equivalent to the single Text property on a - /// non-streaming Completions Choice. 
- /// - /// - /// - public async IAsyncEnumerable GetTextStreaming( - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - bool isFinalIndex = false; - for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++) - { - bool doneWaiting = false; - while (!doneWaiting) - { - lock (_baseChoicesLock) - { - Choice mostRecentChoice = _baseChoices.Last(); - - doneWaiting = _isFinishedStreaming || i < _baseChoices.Count; - isFinalIndex = _isFinishedStreaming && i >= _baseChoices.Count - 1; - } - - if (!doneWaiting) - { - await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false); - } - } - - if (_pumpException != null) - { - throw _pumpException; - } - - string newText = string.Empty; - lock (_baseChoicesLock) - { - if (i < _baseChoices.Count) - { - newText = _baseChoices[i].Text; - } - } - - if (!string.IsNullOrEmpty(newText)) - { - yield return newText; - } - } - } - - internal void EnsureFinishStreaming(Exception pumpException = null) - { - if (!_isFinishedStreaming) - { - _isFinishedStreaming = true; - _pumpException = pumpException; - _updateAvailableEvent.Set(); - } - } - - private T GetLocked(Func func) - { - lock (_baseChoicesLock) - { - return func.Invoke(); - } - } - } -} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs index 70796d90f156c..fa032a2f28f50 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs @@ -17,169 +17,63 @@ public class StreamingCompletions : IDisposable { private readonly Response _baseResponse; private readonly SseReader _baseResponseReader; - private readonly IList _baseCompletions; - private readonly object _baseCompletionsLock = new(); - private readonly IList _streamingChoices; - private readonly object _streamingChoicesLock = new(); - private readonly AsyncAutoResetEvent _updateAvailableEvent; - 
private bool _streamingTaskComplete; + private List _cachedCompletions; private bool _disposedValue; - private Exception _pumpException; - - /// - /// Gets the earliest Completion creation timestamp associated with this streamed response. - /// - public DateTimeOffset Created => GetLocked(() => _baseCompletions.Last().Created); - - /// - /// Gets the unique identifier associated with this streaming Completions response. - /// - public string Id => GetLocked(() => _baseCompletions.Last().Id); - - /// - /// Content filtering results for zero or more prompts in the request. In a streaming request, - /// results for different prompts may arrive at different times or in different orders. - /// - public IReadOnlyList PromptFilterResults - => GetLocked(() => - { - return _baseCompletions.Where(singleBaseCompletion => singleBaseCompletion.PromptFilterResults != null) - .SelectMany(singleBaseCompletion => singleBaseCompletion.PromptFilterResults) - .OrderBy(singleBaseCompletion => singleBaseCompletion.PromptIndex) - .ToList(); - }); internal StreamingCompletions(Response response) { _baseResponse = response; _baseResponseReader = new SseReader(response.ContentStream); - _updateAvailableEvent = new AsyncAutoResetEvent(); - _baseCompletions = new List(); - _streamingChoices = new List(); - _streamingTaskComplete = false; - _ = Task.Run(async () => - { - try - { - while (true) - { - SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent == null) - { - _baseResponse.ContentStream?.Dispose(); - break; - } - - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - throw new InvalidDataException(); - - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - _baseResponse.ContentStream?.Dispose(); - break; - } - - JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); - Completions completionsFromSse = Completions.DeserializeCompletions(sseMessageJson.RootElement); - - lock (_baseCompletionsLock) - { - _baseCompletions.Add(completionsFromSse); - } + _cachedCompletions = new(); + } - foreach (Choice choiceFromSse in completionsFromSse.Choices) - { - lock (_streamingChoicesLock) - { - StreamingChoice existingStreamingChoice = _streamingChoices - .FirstOrDefault(choice => choice.Index == choiceFromSse.Index); - if (existingStreamingChoice == null) - { - StreamingChoice newStreamingChoice = new(choiceFromSse); - _streamingChoices.Add(newStreamingChoice); - _updateAvailableEvent.Set(); - } - else - { - existingStreamingChoice.UpdateFromEventStreamChoice(choiceFromSse); - } - } - } - } - } - catch (Exception pumpException) - { - _pumpException = pumpException; - } - finally - { - lock (_streamingChoicesLock) - { - // If anything went wrong and a StreamingChoice didn't naturally determine it was complete - // based on a non-null finish reason, ensure that nothing's left incomplete (and potentially - // hanging!) now. 
- foreach (StreamingChoice streamingChoice in _streamingChoices) - { - streamingChoice.EnsureFinishStreaming(_pumpException); - } - } - _streamingTaskComplete = true; - _updateAvailableEvent.Set(); - } - }); + internal StreamingCompletions(IEnumerable completions) + { + _cachedCompletions.AddRange(completions); } - public async IAsyncEnumerable GetChoicesStreaming( + public async IAsyncEnumerable EnumerateCompletions( [EnumeratorCancellation] CancellationToken cancellationToken = default) { - bool isFinalIndex = false; - for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++) + if (_cachedCompletions.Any()) { - bool doneWaiting = false; - while (!doneWaiting) + foreach (Completions completions in _cachedCompletions) { - lock (_streamingChoicesLock) - { - doneWaiting = _streamingTaskComplete || i < _streamingChoices.Count; - isFinalIndex = _streamingTaskComplete && i >= _streamingChoices.Count - 1; - } - - if (!doneWaiting) - { - await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false); - } + yield return completions; } + yield break; + } - if (_pumpException != null) + // Open clarification points: + // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested? + // - What should happen if the stream is held too long and it times out? + // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently? + while (!cancellationToken.IsCancellationRequested) + { + SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); + if (sseEvent == null) { - throw _pumpException; + _baseResponse.ContentStream?.Dispose(); + break; } - StreamingChoice newChoice = null; - lock (_streamingChoicesLock) - { - if (i < _streamingChoices.Count) - { - newChoice = _streamingChoices[i]; - } - } + ReadOnlyMemory name = sseEvent.Value.FieldName; + if (!name.Span.SequenceEqual("data".AsSpan())) + throw new InvalidDataException(); - if (newChoice != null) + ReadOnlyMemory value = sseEvent.Value.FieldValue; + if (value.Span.SequenceEqual("[DONE]".AsSpan())) { - yield return newChoice; + _baseResponse.ContentStream?.Dispose(); + break; } - } - } - internal StreamingCompletions( - Completions baseCompletions = null, - IList streamingChoices = null) - { - _baseCompletions.Add(baseCompletions); - _streamingChoices = streamingChoices; - _streamingTaskComplete = true; + JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); + Completions sseCompletions = Completions.DeserializeCompletions(sseMessageJson.RootElement); + _cachedCompletions.Add(sseCompletions); + yield return sseCompletions; + } + _baseResponse?.ContentStream?.Dispose(); } public void Dispose() @@ -200,13 +94,5 @@ protected virtual void Dispose(bool disposing) _disposedValue = true; } } - - private T GetLocked(Func func) - { - lock (_baseCompletionsLock) - { - return func.Invoke(); - } - } } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs index 8f913da5a4dfd..c72a2abcfdead 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs @@ -186,13 +186,13 @@ public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget servic }, }; - Response response = await client.GetChatCompletionsStreamingAsync2( + Response response = await 
client.GetChatCompletionsStreamingAsync( deploymentOrModelName, requestOptions); Assert.That(response, Is.Not.Null); Assert.That(response.Value, Is.Not.Null); - using StreamingChatCompletions2 streamingChatCompletions = response.Value; + using StreamingChatCompletions streamingChatCompletions = response.Value; ChatRole? streamedRole = null; IEnumerable azureContextMessages = null; diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs index c7d3289706775..1f3a2d72ffb3b 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs @@ -154,11 +154,11 @@ public async Task StreamingFunctionCallWorks2(OpenAIClientServiceTarget serviceT MaxTokens = 512, }; - Response response - = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions); + Response response + = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); Assert.That(response, Is.Not.Null); - using StreamingChatCompletions2 streamingChatCompletions = response.Value; + using StreamingChatCompletions streamingChatCompletions = response.Value; ChatRole? 
streamedRole = default; string functionName = null; diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs index 5b8ba0ba9f42e..e5eabcaa1a50c 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs @@ -293,11 +293,11 @@ public async Task StreamingChatCompletions2(OpenAIClientServiceTarget serviceTar }, MaxTokens = 512, }; - Response streamingResponse - = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions); + Response streamingResponse + = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); Assert.That(streamingResponse, Is.Not.Null); - using StreamingChatCompletions2 streamingChatCompletions = streamingResponse.Value; - Assert.That(streamingChatCompletions, Is.InstanceOf()); + using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value; + Assert.That(streamingChatCompletions, Is.InstanceOf()); StringBuilder contentBuilder = new StringBuilder(); bool gotRole = false; @@ -440,9 +440,6 @@ public async Task TokenCutoff(OpenAIClientServiceTarget serviceTarget) [TestCase(OpenAIClientServiceTarget.NonAzure)] public async Task StreamingCompletions(OpenAIClientServiceTarget serviceTarget) { - // Temporary test note: at the time of authoring, content filter results aren't included for completions - // with the latest 2023-07-01-preview service API version. We'll manually configure one version behind - // pending the latest version having these results enabled. OpenAIClient client = GetTestClient(serviceTarget); string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.LegacyCompletions); @@ -467,51 +464,68 @@ public async Task StreamingCompletions(OpenAIClientServiceTarget serviceTarget) // implement IDisposable or otherwise ensure that an `IDisposable` underlying `.Value` is disposed. 
using StreamingCompletions responseValue = response.Value; - int originallyEnumeratedChoices = 0; - var originallyEnumeratedTextParts = new List>(); + Dictionary promptFilterResultsByPromptIndex = new(); + Dictionary finishReasonsByChoiceIndex = new(); + Dictionary textBuildersByChoiceIndex = new(); - await foreach (StreamingChoice choice in responseValue.GetChoicesStreaming()) + await foreach (Completions streamingCompletions in responseValue.EnumerateCompletions()) { - List textPartsForChoice = new(); - StringBuilder choiceTextBuilder = new(); - await foreach (string choiceTextPart in choice.GetTextStreaming()) + if (streamingCompletions.PromptFilterResults is not null) { - choiceTextBuilder.Append(choiceTextPart); - textPartsForChoice.Add(choiceTextPart); + foreach (PromptFilterResult promptFilterResult in streamingCompletions.PromptFilterResults) + { + // When providing multiple prompts, the filter results may arrive across separate messages and + // the payload array index may differ from the in-data index property + AssertExpectedContentFilterResults(promptFilterResult.ContentFilterResults, serviceTarget); + promptFilterResultsByPromptIndex[promptFilterResult.PromptIndex] = promptFilterResult; + } + } + foreach (Choice choice in streamingCompletions.Choices) + { + if (choice.FinishReason.HasValue) + { + // Each choice should only receive a single finish reason and, in this case, it should be + // 'stop' + Assert.That(!finishReasonsByChoiceIndex.ContainsKey(choice.Index)); + Assert.That(choice.FinishReason.Value == CompletionsFinishReason.Stopped); + finishReasonsByChoiceIndex[choice.Index] = choice.FinishReason.Value; + } + + // The 'Text' property of streamed Completions will only contain the incremental update for a + // choice; these should be appended to each other to form the complete text + if (!textBuildersByChoiceIndex.ContainsKey(choice.Index)) + { + textBuildersByChoiceIndex[choice.Index] = new(); + } + 
textBuildersByChoiceIndex[choice.Index].Append(choice.Text); + + Assert.That(choice.LogProbabilityModel, Is.Not.Null); + + if (!string.IsNullOrEmpty(choice.Text)) + { + // Content filter results are only populated when content (text) is present + AssertExpectedContentFilterResults(choice.ContentFilterResults, serviceTarget); + } } - Assert.That(choiceTextBuilder.ToString(), Is.Not.Null.Or.Empty); - Assert.That(choice.LogProbabilityModel, Is.Not.Null); - AssertExpectedContentFilterResults(choice.ContentFilterResults, serviceTarget); - originallyEnumeratedChoices++; - originallyEnumeratedTextParts.Add(textPartsForChoice); } - // Note: these top-level values *are likely not yet populated* until *after* at least one streaming - // choice has arrived. - Assert.That(response.GetRawResponse(), Is.Not.Null.Or.Empty); - Assert.That(responseValue.Id, Is.Not.Null.Or.Empty); - Assert.That(responseValue.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1)))); - Assert.That(responseValue.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7))); + int expectedPromptCount = requestOptions.Prompts.Count; + int expectedChoiceCount = expectedPromptCount * (requestOptions.ChoicesPerPrompt ?? 1); - AssertExpectedPromptFilterResults( - responseValue.PromptFilterResults, - serviceTarget, - expectedCount: requestOptions.Prompts.Count * (requestOptions.ChoicesPerPrompt ?? 
1)); - - // Validate stability of enumeration (non-cancelled case) - IReadOnlyList secondPassChoices = await GetBlockingListFromIAsyncEnumerable( - responseValue.GetChoicesStreaming()); - Assert.AreEqual(originallyEnumeratedChoices, secondPassChoices.Count); - for (int i = 0; i < secondPassChoices.Count; i++) + if (serviceTarget == OpenAIClientServiceTarget.Azure) { - IReadOnlyList secondPassTextParts = await GetBlockingListFromIAsyncEnumerable( - secondPassChoices[i].GetTextStreaming()); - Assert.AreEqual(originallyEnumeratedTextParts[i].Count, secondPassTextParts.Count); - for (int j = 0; j < originallyEnumeratedTextParts[i].Count; j++) + for (int i = 0; i < expectedPromptCount; i++) { - Assert.AreEqual(originallyEnumeratedTextParts[i][j], secondPassTextParts[j]); + Assert.That(promptFilterResultsByPromptIndex.ContainsKey(i)); } } + + for (int i = 0; i < expectedChoiceCount; i++) + { + Assert.That(finishReasonsByChoiceIndex.ContainsKey(i)); + Assert.That(textBuildersByChoiceIndex.ContainsKey(i)); + Assert.That(textBuildersByChoiceIndex[i].ToString(), Is.Not.Null.Or.Empty); + } } [RecordedTest] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs index bd0ee5d1db4e1..8bee03f306da5 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs @@ -11,39 +11,6 @@ namespace Azure.AI.OpenAI.Tests.Samples { public partial class StreamingChat { - //[Test] - //[Ignore("Only verifying that the sample builds")] - //public async Task StreamingChatWithNonAzureOpenAI() - //{ - // #region Snippet:StreamChatMessages - // string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com"; - // var client = new OpenAIClient(nonAzureOpenAIApiKey, new OpenAIClientOptions()); - // var chatCompletionsOptions = new ChatCompletionsOptions() - // { - // Messages = - // { - // new 
ChatMessage(ChatRole.System, "You are a helpful assistant. You will talk like a pirate."), - // new ChatMessage(ChatRole.User, "Can you help me?"), - // new ChatMessage(ChatRole.Assistant, "Arrrr! Of course, me hearty! What can I do for ye?"), - // new ChatMessage(ChatRole.User, "What's the best way to train a parrot?"), - // } - // }; - - // Response response = await client.GetChatCompletionsStreamingAsync( - // deploymentOrModelName: "gpt-3.5-turbo", - // chatCompletionsOptions); - // using StreamingChatCompletions streamingChatCompletions = response.Value; - - // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) - // { - // await foreach (ChatMessage message in choice.GetMessageStreaming()) - // { - // Console.Write(message.Content); - // } - // Console.WriteLine(); - // } - // #endregion - //} [Test] [Ignore("Only verifying that the sample builds")] public async Task StreamingChatWithNonAzureOpenAI() @@ -62,10 +29,10 @@ public async Task StreamingChatWithNonAzureOpenAI() } }; - Response response = await client.GetChatCompletionsStreamingAsync2( + Response response = await client.GetChatCompletionsStreamingAsync( deploymentOrModelName: "gpt-3.5-turbo", chatCompletionsOptions); - using StreamingChatCompletions2 streamingChatCompletions = response.Value; + using StreamingChatCompletions streamingChatCompletions = response.Value; await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates()) From 9d2b330031ffac79d01c53bae330bfc5277546eb Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Wed, 18 Oct 2023 17:13:27 -0700 Subject: [PATCH 05/19] comments, tests, and other cleanup --- .../src/Custom/StreamingChatCompletions.cs | 54 +++++++- .../Custom/StreamingChatCompletionsUpdate.cs | 120 +++++++++++++++++- .../src/Custom/StreamingCompletions.cs | 61 ++++++++- .../tests/ChatFunctionsTests.cs | 53 +------- .../OpenAIInferenceModelFactoryTests.cs.cs | 34 +++++ 
.../tests/OpenAIInferenceTests.cs | 52 +------- 6 files changed, 255 insertions(+), 119 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs index fe868d29f6e60..4da4215ef7751 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs @@ -12,13 +12,24 @@ namespace Azure.AI.OpenAI { + /// + /// Represents a streaming source of instances that allow an + /// application to incrementally use and display text and other information as it becomes available during a + /// generation operation. + /// public class StreamingChatCompletions : IDisposable { private readonly Response _baseResponse; private readonly SseReader _baseResponseReader; - private List _cachedUpdates; + private readonly List _cachedUpdates; private bool _disposedValue; + /// + /// Creates a new instance of based on a REST response stream. + /// + /// + /// The from which server-sent events will be consumed. + /// internal StreamingChatCompletions(Response response) { _baseResponse = response; @@ -26,11 +37,42 @@ internal StreamingChatCompletions(Response response) _cachedUpdates = new(); } + /// + /// Creates a new instance of based on an existing set of + /// instances. Typically used for tests or mocking. + /// + /// + /// An existing collection of instances to use. Update enumeration + /// will yield these instances instead of instances derived from a response stream. + /// internal StreamingChatCompletions(IEnumerable updates) { - _cachedUpdates.AddRange(updates); + _cachedUpdates = new(updates); } + /// + /// Returns an asynchronous enumeration of instances as they + /// become available via the associated network response or other data source. + /// + /// + /// This collection may be iterated over using the "await foreach" statement. 
+ /// + /// + /// + /// An optional used to end enumeration before it would normally complete. + /// + /// + /// Cancellation will immediately close any underlying network response stream and may consequently limit + /// incurred token generation and associated cost. + /// + /// + /// + /// An asynchronous enumeration of instances. + /// + /// + /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not + /// prefixed with the expected 'data:' tag. + /// public async IAsyncEnumerable EnumerateChatUpdates( [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -52,7 +94,7 @@ public async IAsyncEnumerable EnumerateChatUpdat SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); if (sseEvent == null) { - _baseResponse.ContentStream?.Dispose(); + _baseResponse?.ContentStream?.Dispose(); break; } @@ -63,7 +105,7 @@ public async IAsyncEnumerable EnumerateChatUpdat ReadOnlyMemory value = sseEvent.Value.FieldValue; if (value.Span.SequenceEqual("[DONE]".AsSpan())) { - _baseResponse.ContentStream?.Dispose(); + _baseResponse?.ContentStream?.Dispose(); break; } @@ -78,19 +120,21 @@ in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sse _baseResponse?.ContentStream?.Dispose(); } + /// public void Dispose() { Dispose(disposing: true); GC.SuppressFinalize(this); } + /// protected virtual void Dispose(bool disposing) { if (!_disposedValue) { if (disposing) { - _baseResponseReader.Dispose(); + _baseResponseReader?.Dispose(); } _disposedValue = true; diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs index c16e0d20f2e5d..f28ddcf5f94fd 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs @@ -2,34 +2,144 @@ // Licensed under the 
MIT License. using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Threading; -using Azure.Core; namespace Azure.AI.OpenAI { public partial class StreamingChatCompletionsUpdate { + /// + /// Gets a unique identifier associated with this streamed Chat Completions response. + /// + /// + /// + /// Corresponds to $.id in the underlying REST schema. + /// + /// When using Azure OpenAI, note that the values of and may not be + /// populated until the first containing role, content, or + /// function information. + /// public string Id { get; } + /// + /// Gets the first timestamp associated with generation activity for this completions response, + /// represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. + /// + /// + /// + /// Corresponds to $.created in the underlying REST schema. + /// + /// When using Azure OpenAI, note that the values of and may not be + /// populated until the first containing role, content, or + /// function information. + /// public DateTimeOffset Created { get; } + /// + /// Gets the associated with this update. + /// + /// + /// + /// Corresponds to e.g. $.choices[0].delta.role in the underlying REST schema. + /// + /// assignment typically occurs in a single update across a streamed Chat Completions + /// choice and the value should be considered to be persist for all subsequent updates without a + /// that bear the same . + /// public ChatRole? Role { get; } + /// + /// Gets the content fragment associated with this update. + /// + /// + /// + /// Corresponds to e.g. $.choices[0].delta.content in the underlying REST schema. + /// + /// Each update contains only a small number of tokens. When presenting or reconstituting a full, streamed + /// response, all values for the same should be + /// combined. + /// public string ContentUpdate { get; } + /// + /// Gets the name of a function to be called. + /// + /// + /// Corresponds to e.g. 
$.choices[0].delta.function_call.name in the underlying REST schema. + /// public string FunctionName { get; } + /// + /// Gets a function arguments fragment associated with this update. + /// + /// + /// + /// Corresponds to e.g. $.choices[0].delta.function_call.arguments in the underlying REST schema. + /// + /// + /// + /// Each update contains only a small number of tokens. When presenting or reconstituting a full, streamed + /// arguments body, all values for the same + /// should be combined. + /// + /// + /// + /// As is the case for non-streaming , the content provided for function + /// arguments is not guaranteed to be well-formed JSON or to contain expected data. Use of function arguments + /// should validate before consumption. + /// + /// public string FunctionArgumentsUpdate { get; } + /// + /// Gets an optional name associated with the role of the streamed Chat Completion, typically as previously + /// specified in a system message. + /// + /// + /// + /// Corresponds to e.g. $.choices[0].delta.name in the underlying REST schema. + /// + /// public string AuthorName { get; } + /// + /// Gets the associated with this update. + /// + /// + /// + /// Corresponds to e.g. $.choices[0].finish_reason in the underlying REST schema. + /// + /// + /// assignment typically appears in the final streamed update message associated + /// with a choice. + /// + /// public CompletionsFinishReason? FinishReason { get; } + /// + /// Gets the choice index associated with this streamed update. + /// + /// + /// + /// Corresponds to e.g. $.choices[0].index in the underlying REST schema. + /// + /// + /// Unless a value greater than one was provided to ('n' in + /// REST), only one choice will be generated. In that case, this value will always be 0 and may not need to be + /// considered. + /// + /// + /// When providing a value greater than one to , this index + /// identifies which logical response the payload is associated with. 
In the event that a single underlying + /// server-sent event contains multiple choices, multiple instances of + /// will be created. + /// + /// public int? ChoiceIndex { get; } + /// + /// Gets additional data associated with this update that is specific to use with Azure OpenAI and that + /// service's extended capabilities. + /// public AzureChatExtensionsMessageContext AzureExtensionsContext { get; } internal StreamingChatCompletionsUpdate( diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs index fa032a2f28f50..ca1e98923f136 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs @@ -8,18 +8,28 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; -using System.Threading.Tasks; using Azure.Core.Sse; namespace Azure.AI.OpenAI { + /// + /// Represents a streaming source of instances that allow an application to incrementally + /// use and display text and other information as it becomes available during a generation operation. + /// public class StreamingCompletions : IDisposable { private readonly Response _baseResponse; private readonly SseReader _baseResponseReader; - private List _cachedCompletions; + private readonly List _cachedCompletions; private bool _disposedValue; + /// + /// Creates a new instance of that will generate updates from a + /// service response stream. + /// + /// + /// The instance from which server-sent events will be parsed. + /// internal StreamingCompletions(Response response) { _baseResponse = response; @@ -27,11 +37,50 @@ internal StreamingCompletions(Response response) _cachedCompletions = new(); } + /// + /// Creates a new instance of that will yield instances from a provided set + /// of instances. Typically used for tests and mocking. + /// + /// + /// An existing set of instances to use. 
+ /// internal StreamingCompletions(IEnumerable completions) { - _cachedCompletions.AddRange(completions); + _cachedCompletions = new(completions); } + /// + /// Returns an asynchronous enumeration of instances as they + /// become available via the associated network response or other data source. + /// + /// + /// + /// This collection may be iterated over using the "await foreach" statement. + /// + /// + /// Note that, in contrast to , streamed share + /// the same underlying representation and object model as their non-streamed counterpart. The + /// property on a streamed instanced will only contain a + /// small number of new tokens and must be combined with other values to represent + /// a full, reconsistuted message. + /// + /// + /// + /// + /// An optional used to end enumeration before it would normally complete. + /// + /// + /// Cancellation will immediately close any underlying network response stream and may consequently limit + /// incurred token generation and associated cost. + /// + /// + /// + /// An asynchronous enumeration of instances. + /// + /// + /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not + /// prefixed with the expected 'data:' tag. + /// public async IAsyncEnumerable EnumerateCompletions( [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -53,7 +102,7 @@ public async IAsyncEnumerable EnumerateCompletions( SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); if (sseEvent == null) { - _baseResponse.ContentStream?.Dispose(); + _baseResponse?.ContentStream?.Dispose(); break; } @@ -64,7 +113,7 @@ public async IAsyncEnumerable EnumerateCompletions( ReadOnlyMemory value = sseEvent.Value.FieldValue; if (value.Span.SequenceEqual("[DONE]".AsSpan())) { - _baseResponse.ContentStream?.Dispose(); + _baseResponse?.ContentStream?.Dispose(); break; } @@ -88,7 +137,7 @@ protected virtual void Dispose(bool disposing) { if (disposing) { - _baseResponseReader.Dispose(); + _baseResponseReader?.Dispose(); } _disposedValue = true; diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs index 1f3a2d72ffb3b..fd7b345ef854a 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs @@ -84,61 +84,10 @@ public async Task SimpleFunctionCallWorks(OpenAIClientServiceTarget serviceTarge Assert.That(followupResponse.Value.Choices[0].Message.Content, Is.Not.Null.Or.Empty); } - //[RecordedTest] - //[TestCase(OpenAIClientServiceTarget.Azure)] - //[TestCase(OpenAIClientServiceTarget.NonAzure)] - //public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget) - //{ - // OpenAIClient client = GetTestClient(serviceTarget); - // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions); - - // var requestOptions = new ChatCompletionsOptions() - // { - // Functions = { s_futureTemperatureFunction }, - // Messages = - // { - // new ChatMessage(ChatRole.System, "You are a helpful assistant."), - // new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"), - // }, - // MaxTokens = 512, - // }; - - // Response response - // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); - // 
Assert.That(response, Is.Not.Null); - - // using StreamingChatCompletions streamingChatCompletions = response.Value; - - // ChatRole streamedRole = default; - // string functionName = default; - // StringBuilder argumentsBuilder = new(); - - // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming()) - // { - // await foreach (ChatMessage message in choice.GetMessageStreaming()) - // { - // if (message.Role != default) - // { - // streamedRole = message.Role; - // } - // if (message.FunctionCall?.Name != null) - // { - // Assert.That(functionName, Is.Null.Or.Empty); - // functionName = message.FunctionCall.Name; - // } - // argumentsBuilder.Append(message.FunctionCall?.Arguments ?? string.Empty); - // } - // } - - // Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant)); - // Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name)); - // Assert.That(argumentsBuilder.Length, Is.GreaterThan(0)); - //} - [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] [TestCase(OpenAIClientServiceTarget.NonAzure)] - public async Task StreamingFunctionCallWorks2(OpenAIClientServiceTarget serviceTarget) + public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget) { OpenAIClient client = GetTestClient(serviceTarget); string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions); diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs index ea9e18667a16a..30401cad8ce50 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text; +using System.Threading.Tasks; using NUnit.Framework; namespace Azure.AI.OpenAI.Tests @@ -180,5 +182,37 @@ 
public void TestAudioTranslation() Assert.That(audioTranslation.Segments, Is.Not.Null.Or.Empty); Assert.That(audioTranslation.Segments.Count, Is.EqualTo(1)); } + + [Test] + public async Task TestStreamingChatCompletions() + { + const string expectedId = "expected-id-value"; + + StreamingChatCompletionsUpdate firstUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( + expectedId, + DateTime.Now, + role: ChatRole.Assistant, + contentUpdate: "hello"); + StreamingChatCompletionsUpdate secondUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( + expectedId, + DateTime.Now, + contentUpdate: " world"); + StreamingChatCompletionsUpdate thirdUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( + expectedId, + DateTime.Now, + finishReason: CompletionsFinishReason.Stopped); + + using StreamingChatCompletions streamingChatCompletions = AzureOpenAIModelFactory.StreamingChatCompletions( + new[] { firstUpdate, secondUpdate, thirdUpdate }); + + StringBuilder contentBuilder = new(); + await foreach (StreamingChatCompletionsUpdate update in streamingChatCompletions.EnumerateChatUpdates()) + { + Assert.That(update.Id == expectedId); + Assert.That(update.Created > new DateTimeOffset(new DateTime(2023, 1, 1))); + contentBuilder.Append(update.ContentUpdate); + } + Assert.That(contentBuilder.ToString() == "hello world"); + } } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs index e5eabcaa1a50c..951580c9d7c3f 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs @@ -223,60 +223,10 @@ string deploymentOrModelName AssertExpectedContentFilterResults(firstChoice.ContentFilterResults, serviceTarget); } - //[RecordedTest] - //[TestCase(OpenAIClientServiceTarget.Azure)] - //[TestCase(OpenAIClientServiceTarget.NonAzure)] - //public async Task StreamingChatCompletions(OpenAIClientServiceTarget 
serviceTarget) - //{ - // OpenAIClient client = GetTestClient(serviceTarget); - // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( - // serviceTarget, - // OpenAIClientScenario.ChatCompletions); - // var requestOptions = new ChatCompletionsOptions() - // { - // Messages = - // { - // new ChatMessage(ChatRole.System, "You are a helpful assistant."), - // new ChatMessage(ChatRole.User, "Can you help me?"), - // new ChatMessage(ChatRole.Assistant, "Of course! What do you need help with?"), - // new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"), - // }, - // MaxTokens = 512, - // }; - // Response streamingResponse - // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); - // Assert.That(streamingResponse, Is.Not.Null); - // using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value; - // Assert.That(streamingChatCompletions, Is.InstanceOf()); - - // int totalMessages = 0; - - // await foreach (StreamingChatChoice streamingChoice in streamingChatCompletions.GetChoicesStreaming()) - // { - // Assert.That(streamingChoice, Is.Not.Null); - // await foreach (ChatMessage streamingMessage in streamingChoice.GetMessageStreaming()) - // { - // Assert.That(streamingMessage.Role, Is.EqualTo(ChatRole.Assistant)); - // totalMessages++; - // } - // AssertExpectedContentFilterResults(streamingChoice.ContentFilterResults, serviceTarget); - // } - - // Assert.That(totalMessages, Is.GreaterThan(1)); - - // // Note: these top-level values *are likely not yet populated* until *after* at least one streaming - // // choice has arrived. 
- // Assert.That(streamingResponse.GetRawResponse(), Is.Not.Null.Or.Empty); - // Assert.That(streamingChatCompletions.Id, Is.Not.Null.Or.Empty); - // Assert.That(streamingChatCompletions.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1)))); - // Assert.That(streamingChatCompletions.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7))); - // AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1)); - //} - [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] [TestCase(OpenAIClientServiceTarget.NonAzure)] - public async Task StreamingChatCompletions2(OpenAIClientServiceTarget serviceTarget) + public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget) { OpenAIClient client = GetTestClient(serviceTarget); string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( From 2e3f43fe8dea77730ded99bbbfd50620d2f6b652 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Wed, 18 Oct 2023 17:15:59 -0700 Subject: [PATCH 06/19] one orphaned test comment cleanup --- .../tests/AzureChatExtensionsTests.cs | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs index c72a2abcfdead..96b1c56c7bff4 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs @@ -89,68 +89,6 @@ public async Task BasicSearchExtensionWorks( Assert.That(context.Messages.First().Content.Contains("citations")); } - //[RecordedTest] - //[TestCase(OpenAIClientServiceTarget.Azure)] - //public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget) - //{ - // OpenAIClient client = GetTestClient(serviceTarget); - // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName( - // serviceTarget, - // 
OpenAIClientScenario.ChatCompletions); - - // var requestOptions = new ChatCompletionsOptions() - // { - // Messages = - // { - // new ChatMessage(ChatRole.User, "What does PR complete mean?"), - // }, - // MaxTokens = 512, - // AzureExtensionsOptions = new() - // { - // Extensions = - // { - // new AzureChatExtensionConfiguration() - // { - // Type = "AzureCognitiveSearch", - // Parameters = BinaryData.FromObjectAsJson(new - // { - // Endpoint = "https://openaisdktestsearch.search.windows.net", - // IndexName = "openai-test-index-carbon-wiki", - // Key = GetCognitiveSearchApiKey().Key, - // }, - // new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }), - // }, - // } - // }, - // }; - - // Response response = await client.GetChatCompletionsStreamingAsync( - // deploymentOrModelName, - // requestOptions); - // Assert.That(response, Is.Not.Null); - // Assert.That(response.Value, Is.Not.Null); - - // using StreamingChatCompletions streamingChatCompletions = response.Value; - - // int choiceCount = 0; - // List messageChunks = new(); - - // await foreach (StreamingChatChoice streamingChatChoice in response.Value.GetChoicesStreaming()) - // { - // choiceCount++; - // await foreach (ChatMessage chatMessage in streamingChatChoice.GetMessageStreaming()) - // { - // messageChunks.Add(chatMessage); - // } - // } - - // Assert.That(choiceCount, Is.EqualTo(1)); - // Assert.That(messageChunks, Is.Not.Null.Or.Empty); - // Assert.That(messageChunks.Any(chunk => chunk.AzureExtensionsContext != null && chunk.AzureExtensionsContext.Messages.Any(m => m.Role == ChatRole.Tool))); - // //Assert.That(messageChunks.Any(chunk => chunk.Role == ChatRole.Assistant)); - // Assert.That(messageChunks.Any(chunk => !string.IsNullOrWhiteSpace(chunk.Content))); - //} - [RecordedTest] [TestCase(OpenAIClientServiceTarget.Azure)] public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget) From 91333589f303a9ff24775dfcc052f60d6b7b2b46 Mon Sep 17 
00:00:00 2001 From: Travis Wilson Date: Wed, 18 Oct 2023 17:33:53 -0700 Subject: [PATCH 07/19] xml comment fix --- .../Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs | 1 - sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs index 4da4215ef7751..ff6a184e318e3 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs @@ -127,7 +127,6 @@ public void Dispose() GC.SuppressFinalize(this); } - /// protected virtual void Dispose(bool disposing) { if (!_disposedValue) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs index ca1e98923f136..a26baeb5631a7 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs @@ -125,6 +125,7 @@ public async IAsyncEnumerable EnumerateCompletions( _baseResponse?.ContentStream?.Dispose(); } + /// public void Dispose() { Dispose(disposing: true); From 6b520eafa509035d64a059bdc539576768e1132e Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Thu, 19 Oct 2023 12:14:47 -0700 Subject: [PATCH 08/19] test assets, making it real --- sdk/openai/Azure.AI.OpenAI/assets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/Azure.AI.OpenAI/assets.json b/sdk/openai/Azure.AI.OpenAI/assets.json index 9b85d4de7bc32..3aea2fc197d1d 100644 --- a/sdk/openai/Azure.AI.OpenAI/assets.json +++ b/sdk/openai/Azure.AI.OpenAI/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/openai/Azure.AI.OpenAI", - "Tag": "net/openai/Azure.AI.OpenAI_52e82965d8" + "Tag": "net/openai/Azure.AI.OpenAI_2ea3f2ca28" } From 
5694c900f8341cf0f02c20c85eb768888ed11491 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Thu, 19 Oct 2023 13:06:41 -0700 Subject: [PATCH 09/19] test assets pt. 2 (re-record functions) --- sdk/openai/Azure.AI.OpenAI/assets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/Azure.AI.OpenAI/assets.json b/sdk/openai/Azure.AI.OpenAI/assets.json index 3aea2fc197d1d..3a4237461fe9c 100644 --- a/sdk/openai/Azure.AI.OpenAI/assets.json +++ b/sdk/openai/Azure.AI.OpenAI/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/openai/Azure.AI.OpenAI", - "Tag": "net/openai/Azure.AI.OpenAI_2ea3f2ca28" + "Tag": "net/openai/Azure.AI.OpenAI_86a53abc90" } From e1b24f646a893bebeeba3f29ef5da64c576da060 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Fri, 20 Oct 2023 14:44:09 -0700 Subject: [PATCH 10/19] revised pattern using new StreamingResponse --- sdk/openai/Azure.AI.OpenAI/README.md | 8 +- .../api/Azure.AI.OpenAI.netstandard2.0.cs | 36 ++--- .../src/Custom/AzureOpenAIModelFactory.cs | 29 +--- .../src/Custom/OpenAIClient.cs | 63 ++++++-- .../src/Custom/StreamingChatCompletions.cs | 143 ----------------- .../Custom/StreamingChatCompletionsUpdate.cs | 13 +- .../src/Custom/StreamingCompletions.cs | 148 ------------------ .../src/Helpers/SseAsyncEnumerator.cs | 63 ++++++++ .../Azure.AI.OpenAI/src/Helpers/SseReader.cs | 8 +- .../src/Helpers/StreamingResponse.cs | 70 +++++++++ .../tests/AzureChatExtensionsTests.cs | 15 +- .../tests/ChatFunctionsTests.cs | 8 +- .../OpenAIInferenceModelFactoryTests.cs.cs | 41 +++-- .../tests/OpenAIInferenceTests.cs | 31 +--- .../tests/Samples/Sample04_StreamingChat.cs | 10 +- 15 files changed, 256 insertions(+), 430 deletions(-) delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs create mode 100644 
sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md index cd8a7d002be92..0fd3e91de72be 100644 --- a/sdk/openai/Azure.AI.OpenAI/README.md +++ b/sdk/openai/Azure.AI.OpenAI/README.md @@ -206,13 +206,9 @@ var chatCompletionsOptions = new ChatCompletionsOptions() } }; -Response response = await client.GetChatCompletionsStreamingAsync( +await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatCompletionsStreaming( deploymentOrModelName: "gpt-3.5-turbo", - chatCompletionsOptions); -using StreamingChatCompletions streamingChatCompletions = response.Value; - -await foreach (StreamingChatCompletionsUpdate chatUpdate - in streamingChatCompletions.EnumerateChatUpdates()) + chatCompletionsOptions)) { if (chatUpdate.Role.HasValue) { diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index 7a2dfac1a77bc..c74ba62e240da 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -1,3 +1,15 @@ +namespace Azure +{ + public partial class StreamingResponse : System.Collections.Generic.IAsyncEnumerable, System.IDisposable + { + internal StreamingResponse() { } + public void Dispose() { } + protected virtual void Dispose(bool disposing) { } + public System.Collections.Generic.IAsyncEnumerable EnumerateValues() { throw null; } + public Azure.Response GetRawResponse() { throw null; } + System.Collections.Generic.IAsyncEnumerator System.Collections.Generic.IAsyncEnumerable.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken) { throw null; } + } +} namespace Azure.AI.OpenAI { public partial class AudioTranscription @@ -206,9 +218,7 @@ public static partial class AzureOpenAIModelFactory public static 
Azure.AI.OpenAI.ImageGenerations ImageGenerations(System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable data = null) { throw null; } public static Azure.AI.OpenAI.ImageLocation ImageLocation(System.Uri url = null) { throw null; } public static Azure.AI.OpenAI.PromptFilterResult PromptFilterResult(int promptIndex = 0, Azure.AI.OpenAI.ContentFilterResults contentFilterResults = null) { throw null; } - public static Azure.AI.OpenAI.StreamingChatCompletions StreamingChatCompletions(System.Collections.Generic.IEnumerable updates) { throw null; } public static Azure.AI.OpenAI.StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(string id, System.DateTimeOffset created, int? choiceIndex = default(int?), Azure.AI.OpenAI.ChatRole? role = default(Azure.AI.OpenAI.ChatRole?), string authorName = null, string contentUpdate = null, Azure.AI.OpenAI.CompletionsFinishReason? finishReason = default(Azure.AI.OpenAI.CompletionsFinishReason?), string functionName = null, string functionArgumentsUpdate = null, Azure.AI.OpenAI.AzureChatExtensionsMessageContext azureExtensionsContext = null) { throw null; } - public static Azure.AI.OpenAI.StreamingCompletions StreamingCompletions(System.Collections.Generic.IEnumerable completions) { throw null; } } public partial class ChatChoice { @@ -479,14 +489,14 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent public virtual System.Threading.Tasks.Task> GetAudioTranslationAsync(string deploymentId, Azure.AI.OpenAI.AudioTranslationOptions audioTranslationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual 
System.Threading.Tasks.Task> GetChatCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.Response GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.StreamingResponse GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, 
System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.Response GetCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.StreamingResponse GetCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetEmbeddings(string deploymentOrModelName, Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetEmbeddingsAsync(string deploymentOrModelName, Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response 
GetImageGenerations(Azure.AI.OpenAI.ImageGenerationOptions imageGenerationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -511,13 +521,6 @@ internal PromptFilterResult() { } public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } } public int PromptIndex { get { throw null; } } } - public partial class StreamingChatCompletions : System.IDisposable - { - internal StreamingChatCompletions() { } - public void Dispose() { } - protected virtual void Dispose(bool disposing) { } - public System.Collections.Generic.IAsyncEnumerable EnumerateChatUpdates([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - } public partial class StreamingChatCompletionsUpdate { internal StreamingChatCompletionsUpdate() { } @@ -532,13 +535,6 @@ internal StreamingChatCompletionsUpdate() { } public string Id { get { throw null; } } public Azure.AI.OpenAI.ChatRole? 
Role { get { throw null; } } } - public partial class StreamingCompletions : System.IDisposable - { - internal StreamingCompletions() { } - public void Dispose() { } - protected virtual void Dispose(bool disposing) { } - public System.Collections.Generic.IAsyncEnumerable EnumerateCompletions([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - } } namespace Microsoft.Extensions.Azure { diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs index cc90ea19685df..9cdbf2b7d2767 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs @@ -14,11 +14,11 @@ namespace Azure.AI.OpenAI public static partial class AzureOpenAIModelFactory { /// Initializes a new instance of ChatChoice. - /// The chat message associated with this chat completions choice + /// The chat message associated with this chat completions choice. /// The ordered index associated with this chat completions choice. /// The reason that this chat completions choice completed its generated. - /// For streamed choices, the internal representation of a 'delta' payload - /// The category annotations for this chat choice's content filtering + /// For streamed choices, the internal representation of a 'delta' payload. + /// The category annotations for this chat choice's content filtering. /// A new instance for mocking. public static ChatChoice ChatChoice( ChatMessage message = null, @@ -30,16 +30,6 @@ public static ChatChoice ChatChoice( return new ChatChoice(message, index, finishReason, deltaMessage, contentFilterResults); } - /// - /// Initializes a new instance of StreamingCompletions for tests and mocking. - /// - /// A static set of Completions instances to use for tests and mocking. 
- /// A new instance of StreamingCompletions. - public static StreamingCompletions StreamingCompletions(IEnumerable completions) - { - return new StreamingCompletions(completions); - } - public static StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate( string id, DateTimeOffset created, @@ -65,19 +55,6 @@ public static StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate( azureExtensionsContext); } - /// - /// Initializes a new instance of StreamingChatCompletions for tests and mocking. - /// - /// - /// A static collection of StreamingChatCompletionUpdates to use for tests and mocking. - /// - /// A new instance of StreamingChatCompletions - public static StreamingChatCompletions StreamingChatCompletions( - IEnumerable updates) - { - return new StreamingChatCompletions(updates); - } - /// Initializes a new instance of AudioTranscription. /// Transcribed text. /// Language detected in the source audio file. diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs index e46691d42ae88..15fcfc323d50a 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs @@ -4,10 +4,12 @@ #nullable disable using System; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Azure.Core; using Azure.Core.Pipeline; +using Azure.Core.Sse; namespace Azure.AI.OpenAI { @@ -239,9 +241,11 @@ public virtual Task> GetCompletionsAsync( /// /// Service returned a non-success status code. /// - /// A response that, if the request was successful, includes a instance. + /// A response that, if the request was successful, may be asynchronously enumerated for + /// instances. 
/// - public virtual Response GetCompletionsStreaming( + [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")] + public virtual StreamingResponse GetCompletionsStreaming( string deploymentOrModelName, CompletionsOptions completionsOptions, CancellationToken cancellationToken = default) @@ -268,7 +272,12 @@ public virtual Response GetCompletionsStreaming( context); message.BufferResponse = false; Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - return Response.FromValue(new StreamingCompletions(baseResponse), baseResponse); + return new StreamingResponse( + baseResponse, + SseAsyncEnumerator.EnumerateFromSseStream( + baseResponse.ContentStream, + Completions.DeserializeCompletions, + cancellationToken)); } catch (Exception e) { @@ -278,7 +287,8 @@ public virtual Response GetCompletionsStreaming( } /// - public virtual async Task> GetCompletionsStreamingAsync( + [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")] + public virtual async Task> GetCompletionsStreamingAsync( string deploymentOrModelName, CompletionsOptions completionsOptions, CancellationToken cancellationToken = default) @@ -306,7 +316,12 @@ public virtual async Task> GetCompletionsStreamin message.BufferResponse = false; Response baseResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken) .ConfigureAwait(false); - return Response.FromValue(new StreamingCompletions(baseResponse), baseResponse); + return new StreamingResponse( + baseResponse, + SseAsyncEnumerator.EnumerateFromSseStream( + baseResponse.ContentStream, + Completions.DeserializeCompletions, + cancellationToken)); } catch (Exception e) { @@ -420,7 +435,8 @@ public virtual async Task> GetChatCompletionsAsync( /// /// Service returned a non-success status code. /// The response returned from the service. 
- public virtual Response GetChatCompletionsStreaming( + [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")] + public virtual StreamingResponse GetChatCompletionsStreaming( string deploymentOrModelName, ChatCompletionsOptions chatCompletionsOptions, CancellationToken cancellationToken = default) @@ -449,7 +465,12 @@ public virtual Response GetChatCompletionsStreaming( context); message.BufferResponse = false; Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); + return new StreamingResponse( + baseResponse, + SseAsyncEnumerator.EnumerateFromSseStream( + baseResponse.ContentStream, + StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates, + cancellationToken)); } catch (Exception e) { @@ -459,7 +480,8 @@ public virtual Response GetChatCompletionsStreaming( } /// - public virtual async Task> GetChatCompletionsStreamingAsync( + [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")] + public virtual async Task> GetChatCompletionsStreamingAsync( string deploymentOrModelName, ChatCompletionsOptions chatCompletionsOptions, CancellationToken cancellationToken = default) @@ -491,7 +513,12 @@ public virtual async Task> GetChatCompletions message, context, cancellationToken).ConfigureAwait(false); - return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse); + return new StreamingResponse( + baseResponse, + SseAsyncEnumerator.EnumerateFromSseStream( + baseResponse.ContentStream, + StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates, + cancellationToken)); } catch (Exception e) { @@ -709,7 +736,7 @@ public virtual async Task> GetAudioTranscriptionAsy Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId)); Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions)); - using var scope = 
ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription"); + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription"); scope.Start(); audioTranscriptionOptions.InternalNonAzureModelName = deploymentId; @@ -721,7 +748,8 @@ public virtual async Task> GetAudioTranscriptionAsy try { using HttpMessage message = CreateGetAudioTranscriptionRequest(deploymentId, content, context); - rawResponse = await _pipeline.ProcessMessageAsync(message, context).ConfigureAwait(false); + rawResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken) + .ConfigureAwait(false); } catch (Exception e) { @@ -747,7 +775,7 @@ public virtual Response GetAudioTranscription(string deploym Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId)); Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions)); - using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription"); + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription"); scope.Start(); audioTranscriptionOptions.InternalNonAzureModelName = deploymentId; @@ -759,7 +787,7 @@ public virtual Response GetAudioTranscription(string deploym try { using HttpMessage message = CreateGetAudioTranscriptionRequest(deploymentId, content, context); - rawResponse = _pipeline.ProcessMessage(message, context); + rawResponse = _pipeline.ProcessMessage(message, context, cancellationToken); } catch (Exception e) { @@ -788,7 +816,7 @@ public virtual async Task> GetAudioTranslationAsync(s // Custom code: merely linking the deployment ID (== model name) into the request body for non-Azure use audioTranslationOptions.InternalNonAzureModelName = deploymentId; - using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation"); + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation"); scope.Start(); 
audioTranslationOptions.InternalNonAzureModelName = deploymentId; @@ -800,7 +828,8 @@ public virtual async Task> GetAudioTranslationAsync(s try { using HttpMessage message = CreateGetAudioTranslationRequest(deploymentId, content, context); - rawResponse = await _pipeline.ProcessMessageAsync(message, context).ConfigureAwait(false); + rawResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken) + .ConfigureAwait(false); } catch (Exception e) { @@ -826,7 +855,7 @@ public virtual Response GetAudioTranslation(string deploymentI Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId)); Argument.AssertNotNull(audioTranslationOptions, nameof(audioTranslationOptions)); - using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation"); + using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation"); scope.Start(); audioTranslationOptions.InternalNonAzureModelName = deploymentId; @@ -838,7 +867,7 @@ public virtual Response GetAudioTranslation(string deploymentI try { using HttpMessage message = CreateGetAudioTranslationRequest(deploymentId, content, context); - rawResponse = _pipeline.ProcessMessage(message, context); + rawResponse = _pipeline.ProcessMessage(message, context, cancellationToken); } catch (Exception e) { diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs deleted file mode 100644 index ff6a184e318e3..0000000000000 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Threading; -using Azure.Core.Sse; - -namespace Azure.AI.OpenAI -{ - /// - /// Represents a streaming source of instances that allow an - /// application to incrementally use and display text and other information as it becomes available during a - /// generation operation. - /// - public class StreamingChatCompletions : IDisposable - { - private readonly Response _baseResponse; - private readonly SseReader _baseResponseReader; - private readonly List _cachedUpdates; - private bool _disposedValue; - - /// - /// Creates a new instance of based on a REST response stream. - /// - /// - /// The from which server-sent events will be consumed. - /// - internal StreamingChatCompletions(Response response) - { - _baseResponse = response; - _baseResponseReader = new SseReader(response.ContentStream); - _cachedUpdates = new(); - } - - /// - /// Creates a new instance of based on an existing set of - /// instances. Typically used for tests or mocking. - /// - /// - /// An existing collection of instances to use. Update enumeration - /// will yield these instances instead of instances derived from a response stream. - /// - internal StreamingChatCompletions(IEnumerable updates) - { - _cachedUpdates = new(updates); - } - - /// - /// Returns an asynchronous enumeration of instances as they - /// become available via the associated network response or other data source. - /// - /// - /// This collection may be iterated over using the "await foreach" statement. - /// - /// - /// - /// An optional used to end enumeration before it would normally complete. - /// - /// - /// Cancellation will immediately close any underlying network response stream and may consequently limit - /// incurred token generation and associated cost. - /// - /// - /// - /// An asynchronous enumeration of instances. 
- /// - /// - /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not - /// prefixed with the expected 'data:' tag. - /// - public async IAsyncEnumerable EnumerateChatUpdates( - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - if (_cachedUpdates.Any()) - { - foreach (StreamingChatCompletionsUpdate update in _cachedUpdates) - { - yield return update; - } - yield break; - } - - // Open clarification points: - // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested? - // - What should happen if the stream is held too long and it times out? - // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently? - while (!cancellationToken.IsCancellationRequested) - { - SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent == null) - { - _baseResponse?.ContentStream?.Dispose(); - break; - } - - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - throw new InvalidDataException(); - - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - _baseResponse?.ContentStream?.Dispose(); - break; - } - - JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); - foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate - in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement)) - { - _cachedUpdates.Add(streamingChatCompletionsUpdate); - yield return streamingChatCompletionsUpdate; - } - } - _baseResponse?.ContentStream?.Dispose(); - } - - /// - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _baseResponseReader?.Dispose(); - } - - 
_disposedValue = true; - } - } - } -} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs index f28ddcf5f94fd..7232ca9845853 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs @@ -2,9 +2,18 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using Azure.Core.Sse; namespace Azure.AI.OpenAI { + /// + /// Represents an incremental update to a streamed Chat Completions response. + /// public partial class StreamingChatCompletionsUpdate { /// @@ -84,8 +93,8 @@ public partial class StreamingChatCompletionsUpdate /// /// /// As is the case for non-streaming , the content provided for function - /// arguments is not guaranteed to be well-formed JSON or to contain expected data. Use of function arguments - /// should validate before consumption. + /// arguments is not guaranteed to be well-formed JSON or to contain expected data. Callers should validate + /// function arguments before using them. /// /// public string FunctionArgumentsUpdate { get; } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs deleted file mode 100644 index a26baeb5631a7..0000000000000 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Threading; -using Azure.Core.Sse; - -namespace Azure.AI.OpenAI -{ - /// - /// Represents a streaming source of instances that allow an application to incrementally - /// use and display text and other information as it becomes available during a generation operation. - /// - public class StreamingCompletions : IDisposable - { - private readonly Response _baseResponse; - private readonly SseReader _baseResponseReader; - private readonly List _cachedCompletions; - private bool _disposedValue; - - /// - /// Creates a new instance of that will generate updates from a - /// service response stream. - /// - /// - /// The instance from which server-sent events will be parsed. - /// - internal StreamingCompletions(Response response) - { - _baseResponse = response; - _baseResponseReader = new SseReader(response.ContentStream); - _cachedCompletions = new(); - } - - /// - /// Creates a new instance of that will yield instances from a provided set - /// of instances. Typically used for tests and mocking. - /// - /// - /// An existing set of instances to use. - /// - internal StreamingCompletions(IEnumerable completions) - { - _cachedCompletions = new(completions); - } - - /// - /// Returns an asynchronous enumeration of instances as they - /// become available via the associated network response or other data source. - /// - /// - /// - /// This collection may be iterated over using the "await foreach" statement. - /// - /// - /// Note that, in contrast to , streamed share - /// the same underlying representation and object model as their non-streamed counterpart. The - /// property on a streamed instanced will only contain a - /// small number of new tokens and must be combined with other values to represent - /// a full, reconsistuted message. 
- /// - /// - /// - /// - /// An optional used to end enumeration before it would normally complete. - /// - /// - /// Cancellation will immediately close any underlying network response stream and may consequently limit - /// incurred token generation and associated cost. - /// - /// - /// - /// An asynchronous enumeration of instances. - /// - /// - /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not - /// prefixed with the expected 'data:' tag. - /// - public async IAsyncEnumerable EnumerateCompletions( - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - if (_cachedCompletions.Any()) - { - foreach (Completions completions in _cachedCompletions) - { - yield return completions; - } - yield break; - } - - // Open clarification points: - // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested? - // - What should happen if the stream is held too long and it times out? - // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently? - while (!cancellationToken.IsCancellationRequested) - { - SseLine? 
sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent == null) - { - _baseResponse?.ContentStream?.Dispose(); - break; - } - - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - throw new InvalidDataException(); - - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - _baseResponse?.ContentStream?.Dispose(); - break; - } - - JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue); - Completions sseCompletions = Completions.DeserializeCompletions(sseMessageJson.RootElement); - _cachedCompletions.Add(sseCompletions); - yield return sseCompletions; - } - _baseResponse?.ContentStream?.Dispose(); - } - - /// - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _baseResponseReader?.Dispose(); - } - - _disposedValue = true; - } - } - } -} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs b/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs new file mode 100644 index 0000000000000..9c52fae14585d --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; + +namespace Azure.Core.Sse +{ + internal static class SseAsyncEnumerator + { + internal static async IAsyncEnumerable EnumerateFromSseStream( + Stream stream, + Func> multiElementDeserializer, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + try + { + using SseReader sseReader = new(stream); + while (!cancellationToken.IsCancellationRequested) + { + SseLine? 
sseEvent = await sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); + if (sseEvent is not null) + { + ReadOnlyMemory name = sseEvent.Value.FieldName; + if (!name.Span.SequenceEqual("data".AsSpan())) + { + throw new InvalidDataException(); + } + ReadOnlyMemory value = sseEvent.Value.FieldValue; + if (value.Span.SequenceEqual("[DONE]".AsSpan())) + { + break; + } + JsonDocument sseMessageJson = JsonDocument.Parse(value); + IEnumerable newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); + foreach (T item in newItems) + { + yield return item; + } + } + } + } + finally + { + // Always dispose the stream immediately once enumeration is complete for any reason + stream.Dispose(); + } + } + + internal static IAsyncEnumerable EnumerateFromSseStream( + Stream stream, + Func elementDeserializer, + CancellationToken cancellationToken = default) + => EnumerateFromSseStream( + stream, + (element) => new T[] { elementDeserializer.Invoke(element) }, + cancellationToken); + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseReader.cs b/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseReader.cs index 6e39bad0565db..80636253e3107 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseReader.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseReader.cs @@ -24,12 +24,12 @@ public SseReader(Stream stream) { while (true) { - var line = TryReadLine(); + SseLine? line = TryReadLine(); if (line == null) return null; if (line.Value.IsEmpty) throw new InvalidDataException("event expected."); - var empty = TryReadLine(); + SseLine? empty = TryReadLine(); if (empty != null && !empty.Value.IsEmpty) throw new NotSupportedException("Multi-filed events not supported."); if (!line.Value.IsComment) @@ -42,12 +42,12 @@ public SseReader(Stream stream) { while (true) { - var line = await TryReadLineAsync().ConfigureAwait(false); + SseLine? 
line = await TryReadLineAsync().ConfigureAwait(false); if (line == null) return null; if (line.Value.IsEmpty) throw new InvalidDataException("event expected."); - var empty = await TryReadLineAsync().ConfigureAwait(false); + SseLine? empty = await TryReadLineAsync().ConfigureAwait(false); if (empty != null && !empty.Value.IsEmpty) throw new NotSupportedException("Multi-filed events not supported."); if (!line.Value.IsComment) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs b/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs new file mode 100644 index 0000000000000..a91d8a3ec9697 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Azure; + +/// +/// Represents an operation response with streaming content that can be deserialized and enumerated while the response +/// is still being received. +/// +/// The data type representative of distinct, streamable items. +public class StreamingResponse + : IDisposable + , IAsyncEnumerable +{ + private Response _rawResponse { get; } + private IAsyncEnumerable _asyncEnumerableSource { get; } + private bool _disposedValue { get; set; } + + internal StreamingResponse(Response rawResponse, IAsyncEnumerable asyncEnumerableSource) + { + _rawResponse = rawResponse; + _asyncEnumerableSource = asyncEnumerableSource; + } + + /// + /// Gets the underlying instance that this may enumerate + /// over. + /// + /// The instance attached to this . + public Response GetRawResponse() => _rawResponse; + + /// + /// Gets the asynchronously enumerable collection of distinct, streamable items in the response. + /// + /// + /// The return value of this method may be used with the "await foreach" statement. 
+ /// + /// As explicitly implements , callers may + /// enumerate a instance directly instead of calling this method. + /// + /// + /// + public IAsyncEnumerable EnumerateValues() => this; + + /// + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _rawResponse?.Dispose(); + } + _disposedValue = true; + } + } + + IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) + => _asyncEnumerableSource.GetAsyncEnumerator(cancellationToken); +} diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs index 96b1c56c7bff4..07be71f1a2e7d 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs @@ -47,7 +47,7 @@ public async Task BasicSearchExtensionWorks( { Endpoint = "https://openaisdktestsearch.search.windows.net", IndexName = "openai-test-index-carbon-wiki", - Key = GetCognitiveSearchApiKey().Key, + GetCognitiveSearchApiKey().Key, }, new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }), }, @@ -116,7 +116,7 @@ public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget servic { Endpoint = "https://openaisdktestsearch.search.windows.net", IndexName = "openai-test-index-carbon-wiki", - Key = GetCognitiveSearchApiKey().Key, + GetCognitiveSearchApiKey().Key, }, new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }), }, @@ -124,19 +124,16 @@ public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget servic }, }; - Response response = await client.GetChatCompletionsStreamingAsync( - deploymentOrModelName, - requestOptions); + using StreamingResponse response + = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); 
Assert.That(response, Is.Not.Null); - Assert.That(response.Value, Is.Not.Null); - - using StreamingChatCompletions streamingChatCompletions = response.Value; + Assert.That(response.GetRawResponse(), Is.Not.Null); ChatRole? streamedRole = null; IEnumerable azureContextMessages = null; StringBuilder contentBuilder = new(); - await foreach (StreamingChatCompletionsUpdate chatUpdate in response.Value.EnumerateChatUpdates()) + await foreach (StreamingChatCompletionsUpdate chatUpdate in response) { if (chatUpdate.Role.HasValue) { diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs index fd7b345ef854a..a78a1e3a63bd5 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs @@ -2,11 +2,9 @@ // Licensed under the MIT License. using System; -using System.Linq; using System.Text; using System.Text.Json; using System.Threading.Tasks; -using Azure.Core.Pipeline; using Azure.Core.TestFramework; using NUnit.Framework; @@ -103,17 +101,15 @@ public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTa MaxTokens = 512, }; - Response response + StreamingResponse response = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); Assert.That(response, Is.Not.Null); - using StreamingChatCompletions streamingChatCompletions = response.Value; - ChatRole? 
streamedRole = default; string functionName = null; StringBuilder argumentsBuilder = new(); - await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates()) + await foreach (StreamingChatCompletionsUpdate chatUpdate in response) { if (chatUpdate.Role.HasValue) { diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs index 30401cad8ce50..eff9e748f1368 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs @@ -188,25 +188,34 @@ public async Task TestStreamingChatCompletions() { const string expectedId = "expected-id-value"; - StreamingChatCompletionsUpdate firstUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( - expectedId, - DateTime.Now, - role: ChatRole.Assistant, - contentUpdate: "hello"); - StreamingChatCompletionsUpdate secondUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( - expectedId, - DateTime.Now, - contentUpdate: " world"); - StreamingChatCompletionsUpdate thirdUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( - expectedId, - DateTime.Now, - finishReason: CompletionsFinishReason.Stopped); + StreamingChatCompletionsUpdate[] updates = new[] + { + AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( + expectedId, + DateTime.Now, + role: ChatRole.Assistant, + contentUpdate: "hello"), + AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( + expectedId, + DateTime.Now, + contentUpdate: " world"), + AzureOpenAIModelFactory.StreamingChatCompletionsUpdate( + expectedId, + DateTime.Now, + finishReason: CompletionsFinishReason.Stopped), + }; - using StreamingChatCompletions streamingChatCompletions = AzureOpenAIModelFactory.StreamingChatCompletions( - new[] { firstUpdate, secondUpdate, thirdUpdate }); + async IAsyncEnumerable EnumerateMockUpdates() + { + 
foreach (StreamingChatCompletionsUpdate update in updates) + { + yield return update; + } + await Task.Delay(0); + } StringBuilder contentBuilder = new(); - await foreach (StreamingChatCompletionsUpdate update in streamingChatCompletions.EnumerateChatUpdates()) + await foreach (StreamingChatCompletionsUpdate update in EnumerateMockUpdates()) { Assert.That(update.Id == expectedId); Assert.That(update.Created > new DateTimeOffset(new DateTime(2023, 1, 1))); diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs index 951580c9d7c3f..bd26bde2ebb4d 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs @@ -243,18 +243,16 @@ public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarg }, MaxTokens = 512, }; - Response streamingResponse + StreamingResponse response = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions); - Assert.That(streamingResponse, Is.Not.Null); - using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value; - Assert.That(streamingChatCompletions, Is.InstanceOf()); + Assert.That(response, Is.Not.Null); - StringBuilder contentBuilder = new StringBuilder(); + StringBuilder contentBuilder = new(); bool gotRole = false; bool gotRequestContentFilterResults = false; bool gotResponseContentFilterResults = false; - await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates()) + await foreach (StreamingChatCompletionsUpdate chatUpdate in response) { Assert.That(chatUpdate, Is.Not.Null); @@ -404,21 +402,16 @@ public async Task StreamingCompletions(OpenAIClientServiceTarget serviceTarget) LogProbabilityCount = 1, }; - Response response = await client.GetCompletionsStreamingAsync( + using StreamingResponse response = await client.GetCompletionsStreamingAsync( deploymentOrModelName, 
requestOptions); Assert.That(response, Is.Not.Null); - // StreamingCompletions implements IDisposable; capturing the .Value field of `response` with a `using` - // statement is unusual but properly ensures that `.Dispose()` will be called, as `Response` does *not* - // implement IDisposable or otherwise ensure that an `IDisposable` underlying `.Value` is disposed. - using StreamingCompletions responseValue = response.Value; - Dictionary promptFilterResultsByPromptIndex = new(); Dictionary finishReasonsByChoiceIndex = new(); Dictionary textBuildersByChoiceIndex = new(); - await foreach (Completions streamingCompletions in responseValue.EnumerateCompletions()) + await foreach (Completions streamingCompletions in response) { if (streamingCompletions.PromptFilterResults is not null) { @@ -495,18 +488,6 @@ public void JsonTypeSerialization() Assert.That(messageFromUtf8Bytes.Content, Is.EqualTo(originalMessage.Content)); } - // Lightweight reimplementation of .NET 7 .ToBlockingEnumerable().ToList() - private static async Task> GetBlockingListFromIAsyncEnumerable( - IAsyncEnumerable asyncValues) - { - List result = new(); - await foreach (T asyncValue in asyncValues) - { - result.Add(asyncValue); - } - return result; - } - private void AssertExpectedPromptFilterResults( IReadOnlyList promptFilterResults, OpenAIClientServiceTarget serviceTarget, diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs index 8bee03f306da5..7f78872e44f11 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs @@ -2,9 +2,7 @@ // Licensed under the MIT License. 
using System; -using System.Collections.Generic; using System.Threading.Tasks; -using Azure.Identity; using NUnit.Framework; namespace Azure.AI.OpenAI.Tests.Samples @@ -29,13 +27,9 @@ public async Task StreamingChatWithNonAzureOpenAI() } }; - Response response = await client.GetChatCompletionsStreamingAsync( + await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatCompletionsStreaming( deploymentOrModelName: "gpt-3.5-turbo", - chatCompletionsOptions); - using StreamingChatCompletions streamingChatCompletions = response.Value; - - await foreach (StreamingChatCompletionsUpdate chatUpdate - in streamingChatCompletions.EnumerateChatUpdates()) + chatCompletionsOptions)) { if (chatUpdate.Role.HasValue) { From 1a6a3a3e2e35d37fd10e6522f5bef476e34d8025 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Fri, 20 Oct 2023 15:17:28 -0700 Subject: [PATCH 11/19] use delegate resolver for stronger response/enum connection --- .../api/Azure.AI.OpenAI.netstandard2.0.cs | 1 + .../src/Custom/OpenAIClient.cs | 34 ++++++++++--------- .../src/Helpers/StreamingResponse.cs | 28 +++++++++++++-- 3 files changed, 45 insertions(+), 18 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index c74ba62e240da..180f775f21a29 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -3,6 +3,7 @@ namespace Azure public partial class StreamingResponse : System.Collections.Generic.IAsyncEnumerable, System.IDisposable { internal StreamingResponse() { } + public static Azure.StreamingResponse CreateFromResponse(Azure.Response response, System.Func> asyncEnumerableProcessor) { throw null; } public void Dispose() { } protected virtual void Dispose(bool disposing) { } public System.Collections.Generic.IAsyncEnumerable EnumerateValues() { throw null; } diff --git 
a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs index 15fcfc323d50a..532617fa58945 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs @@ -272,10 +272,10 @@ public virtual StreamingResponse GetCompletionsStreaming( context); message.BufferResponse = false; Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - return new StreamingResponse( + return StreamingResponse.CreateFromResponse( baseResponse, - SseAsyncEnumerator.EnumerateFromSseStream( - baseResponse.ContentStream, + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( + responseForEnumeration.ContentStream, Completions.DeserializeCompletions, cancellationToken)); } @@ -316,10 +316,10 @@ public virtual async Task> GetCompletionsStreamin message.BufferResponse = false; Response baseResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken) .ConfigureAwait(false); - return new StreamingResponse( + return StreamingResponse.CreateFromResponse( baseResponse, - SseAsyncEnumerator.EnumerateFromSseStream( - baseResponse.ContentStream, + (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( + responseForEnumeration.ContentStream, Completions.DeserializeCompletions, cancellationToken)); } @@ -465,12 +465,13 @@ public virtual StreamingResponse GetChatCompleti context); message.BufferResponse = false; Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken); - return new StreamingResponse( + return StreamingResponse.CreateFromResponse( baseResponse, - SseAsyncEnumerator.EnumerateFromSseStream( - baseResponse.ContentStream, - StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates, - cancellationToken)); + (responseForEnumeration) + => SseAsyncEnumerator.EnumerateFromSseStream( + responseForEnumeration.ContentStream, + 
StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates, + cancellationToken)); } catch (Exception e) { @@ -513,12 +514,13 @@ public virtual async Task> Get message, context, cancellationToken).ConfigureAwait(false); - return new StreamingResponse( + return StreamingResponse.CreateFromResponse( baseResponse, - SseAsyncEnumerator.EnumerateFromSseStream( - baseResponse.ContentStream, - StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates, - cancellationToken)); + (responseForEnumeration) + => SseAsyncEnumerator.EnumerateFromSseStream( + responseForEnumeration.ContentStream, + StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates, + cancellationToken)); } catch (Exception e) { diff --git a/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs b/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs index a91d8a3ec9697..f0122f425d4d1 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs @@ -20,10 +20,34 @@ public class StreamingResponse private IAsyncEnumerable _asyncEnumerableSource { get; } private bool _disposedValue { get; set; } - internal StreamingResponse(Response rawResponse, IAsyncEnumerable asyncEnumerableSource) + private StreamingResponse() { } + + private StreamingResponse( + Response rawResponse, + Func> asyncEnumerableProcessor) { _rawResponse = rawResponse; - _asyncEnumerableSource = asyncEnumerableSource; + _asyncEnumerableSource = asyncEnumerableProcessor.Invoke(rawResponse); + } + + /// + /// Creates a new instance of using the provided underlying HTTP response. The + /// provided function will be used to resolve the response into an asynchronous enumeration of streamed response + /// items. + /// + /// The HTTP response. + /// + /// The function that will resolve the provided response into an IAsyncEnumerable. 
+ /// + /// + /// A new instance of that will be capable of asynchronous enumeration of + /// items from the HTTP response. + /// + public static StreamingResponse CreateFromResponse( + Response response, + Func> asyncEnumerableProcessor) + { + return new(response, asyncEnumerableProcessor); } /// From 42066ad1fc0809d57baa6566193a7738185b9d4f Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 23 Oct 2023 13:29:58 -0700 Subject: [PATCH 12/19] add a snippet for streaming functions --- sdk/openai/Azure.AI.OpenAI/README.md | 38 +++++++++++++++++++ .../tests/Samples/Sample07_ChatFunctions.cs | 36 +++++++++++++++++- 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md index 0fd3e91de72be..d596f1318f37c 100644 --- a/sdk/openai/Azure.AI.OpenAI/README.md +++ b/sdk/openai/Azure.AI.OpenAI/README.md @@ -335,6 +335,44 @@ if (responseChoice.FinishReason == CompletionsFinishReason.FunctionCall) } ``` +When using streaming, capture streaming response components as they arrive and accumulate streaming function arguments +in the same manner used for streaming content. Then, in the place of using the `ChatMessage` from the non-streaming +response, instead add a new `ChatMessage` instance for history, created from the streamed information. + +```C# Snippet::ChatFunctions::StreamingFunctions +string functionName = null; +StringBuilder contentBuilder = new(); +StringBuilder functionArgumentsBuilder = new(); +ChatRole streamedRole = default; +CompletionsFinishReason finishReason = default; + +using StreamingResponse streamingResponse + = client.GetChatCompletionsStreaming( + "gpt-35-turbo-0613", + chatCompletionsOptions); +await foreach (StreamingChatCompletionsUpdate update in streamingResponse) +{ + contentBuilder.Append(update.ContentUpdate); + functionName ??= update.FunctionName; + functionArgumentsBuilder.Append(update.FunctionArgumentsUpdate); + streamedRole = update.Role ?? 
streamedRole; + finishReason = update.FinishReason ?? finishReason; +} + +if (finishReason == CompletionsFinishReason.FunctionCall) +{ + string lastContent = contentBuilder.ToString(); + string unvalidatedArguments = functionArgumentsBuilder.ToString(); + ChatMessage chatMessageForHistory = new(streamedRole, lastContent) + { + FunctionCall = new(functionName, unvalidatedArguments), + }; + conversationMessages.Add(chatMessageForHistory); + + // Handle from here just like the non-streaming case +} +``` + ### Use your own data with Azure OpenAI The use your own data feature is unique to Azure OpenAI and won't work with a client configured to use the non-Azure service. diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample07_ChatFunctions.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample07_ChatFunctions.cs index 7b747adc2dfdc..e018ec092d69e 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample07_ChatFunctions.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample07_ChatFunctions.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Generic; -using System.IO; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Azure.Identity; @@ -99,6 +99,40 @@ public async Task ChatFunctions() } } #endregion + + #region Snippet::ChatFunctions::StreamingFunctions + string functionName = null; + StringBuilder contentBuilder = new(); + StringBuilder functionArgumentsBuilder = new(); + ChatRole streamedRole = default; + CompletionsFinishReason finishReason = default; + + using StreamingResponse streamingResponse + = client.GetChatCompletionsStreaming( + "gpt-35-turbo-0613", + chatCompletionsOptions); + await foreach (StreamingChatCompletionsUpdate update in streamingResponse) + { + contentBuilder.Append(update.ContentUpdate); + functionName ??= update.FunctionName; + functionArgumentsBuilder.Append(update.FunctionArgumentsUpdate); + streamedRole = update.Role ?? streamedRole; + finishReason = update.FinishReason ?? 
finishReason; + } + + if (finishReason == CompletionsFinishReason.FunctionCall) + { + string lastContent = contentBuilder.ToString(); + string unvalidatedArguments = functionArgumentsBuilder.ToString(); + ChatMessage chatMessageForHistory = new(streamedRole, lastContent) + { + FunctionCall = new(functionName, unvalidatedArguments), + }; + conversationMessages.Add(chatMessageForHistory); + + // Handle from here just like the non-streaming case + } + #endregion } } } From a5312a952a87a4245e8e58d37e6e863bb2d5d2dd Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 23 Oct 2023 14:41:03 -0700 Subject: [PATCH 13/19] also add a snippet for streaming with multiple choices --- sdk/openai/Azure.AI.OpenAI/README.md | 39 +++++++++++++++++++ .../tests/Samples/Sample04_StreamingChat.cs | 39 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md index d596f1318f37c..be7896b3d93a3 100644 --- a/sdk/openai/Azure.AI.OpenAI/README.md +++ b/sdk/openai/Azure.AI.OpenAI/README.md @@ -221,6 +221,39 @@ await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatComple } ``` +When explicitly requesting more than one `Choice` while streaming, use the `ChoiceIndex` property on +`StreamingChatCompletionsUpdate` to determine which `Choice` each update corresponds to. + +```C# Snippet:StreamChatMessagesWithMultipleChoices +// A ChoiceCount > 1 will feature multiple, parallel, independent text generations arriving on the +// same response. This may be useful when choosing between multiple candidates for a single request. 
+var chatCompletionsOptions = new ChatCompletionsOptions() +{ + Messages = { new ChatMessage(ChatRole.User, "Write a limerick about bananas.") }, + ChoiceCount = 4 +}; + +await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatCompletionsStreaming( + deploymentOrModelName: "gpt-3.5-turbo", + chatCompletionsOptions)) +{ + // Choice-specific information like Role and ContentUpdate will also provide a ChoiceIndex that allows + // StreamingChatCompletionsUpdate data for independent choices to be appropriately separated. + if (chatUpdate.ChoiceIndex.HasValue) + { + int choiceIndex = chatUpdate.ChoiceIndex.Value; + if (chatUpdate.Role.HasValue) + { + textBoxes[choiceIndex].Text += $"{chatUpdate.Role.Value.ToString().ToUpperInvariant()}: "; + } + if (!string.IsNullOrEmpty(chatUpdate.ContentUpdate)) + { + textBoxes[choiceIndex].Text += chatUpdate.ContentUpdate; + } + } +} +``` + ### Use Chat Functions Chat Functions allow a caller of Chat Completions to define capabilities that the model can use to extend its @@ -373,6 +406,12 @@ if (finishReason == CompletionsFinishReason.FunctionCall) } ``` +Please note: while streamed function information (name, arguments) may be evaluated as it arrives, it should not be +considered complete or confirmed until the `FinishReason` of `FunctionCall` is received. It may be appropriate to make +best-effort attempts at "warm-up" or other speculative preparation based on a function name or particular key/value +appearing in the accumulated, partial JSON arguments, but no strong assumptions about validity, ordering, or other +details should be evaluated until the arguments are fully available and confirmed via `FinishReason`. + ### Use your own data with Azure OpenAI The use your own data feature is unique to Azure OpenAI and won't work with a client configured to use the non-Azure service. 
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs index 7f78872e44f11..13d389ffa45c0 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs @@ -42,5 +42,44 @@ public async Task StreamingChatWithNonAzureOpenAI() } #endregion } + + [Test] + [Ignore("Only verifying that the sample builds")] + public async Task StreamingChatWithMultipleChoices() + { + string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com"; + var client = new OpenAIClient(nonAzureOpenAIApiKey, new OpenAIClientOptions()); + (object, string Text)[] textBoxes = new (object, string)[4]; + + #region Snippet:StreamChatMessagesWithMultipleChoices + // A ChoiceCount > 1 will feature multiple, parallel, independent text generations arriving on the + // same response. This may be useful when choosing between multiple candidates for a single request. + var chatCompletionsOptions = new ChatCompletionsOptions() + { + Messages = { new ChatMessage(ChatRole.User, "Write a limerick about bananas.") }, + ChoiceCount = 4 + }; + + await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatCompletionsStreaming( + deploymentOrModelName: "gpt-3.5-turbo", + chatCompletionsOptions)) + { + // Choice-specific information like Role and ContentUpdate will also provide a ChoiceIndex that allows + // StreamingChatCompletionsUpdate data for independent choices to be appropriately separated. 
+ if (chatUpdate.ChoiceIndex.HasValue) + { + int choiceIndex = chatUpdate.ChoiceIndex.Value; + if (chatUpdate.Role.HasValue) + { + textBoxes[choiceIndex].Text += $"{chatUpdate.Role.Value.ToString().ToUpperInvariant()}: "; + } + if (!string.IsNullOrEmpty(chatUpdate.ContentUpdate)) + { + textBoxes[choiceIndex].Text += chatUpdate.ContentUpdate; + } + } + } + #endregion + } } } From 046e026e31d20f2e697529a14c901be4313c3c32 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 23 Oct 2023 14:58:16 -0700 Subject: [PATCH 14/19] speculative CHANGELOG update --- sdk/openai/Azure.AI.OpenAI/CHANGELOG.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md index fd55414684f89..8df3b19f05626 100644 --- a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md +++ b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md @@ -6,6 +6,21 @@ ### Breaking Changes +- Streaming Completions and Streaming Chat Completions have been significantly updated to use simpler, shallower usage + patterns and data representations. The goal of these changes is to make streaming much easier to consume in common + cases while still retaining full functionality in more complex ones (e.g. with multiple choices requested). + - A new `StreamingResponse<T>` type is introduced that implicitly exposes an `IAsyncEnumerable<T>` derived from + the underlying response. + - `OpenAI.GetCompletionsStreaming()` now returns a `StreamingResponse<Completions>` that may be directly + enumerated over. `StreamingCompletions`, `StreamingChoice`, and the corresponding methods are removed. + - Because Chat Completions use a distinct structure for their streaming response messages, a new + `StreamingChatCompletionsUpdate` type is introduced that encapsulates this update data. + - Correspondingly, `OpenAI.GetChatCompletionsStreaming()` now returns a + `StreamingResponse<StreamingChatCompletionsUpdate>` that may be enumerated over directly. 
+ `StreamingChatCompletions`, `StreamingChatChoice`, and related methods are removed. + - For more information, please see [the related pull request description]() as well as the updated snippets in the + project README. + ### Bugs Fixed ### Other Changes From 56710f609c1834ff00a8a4649f71f83f8407269a Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Wed, 25 Oct 2023 14:24:03 -0700 Subject: [PATCH 15/19] basic, standalone test coverage for StreamingResponse --- .../tests/StreamingResponseTests.cs | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 sdk/openai/Azure.AI.OpenAI/tests/StreamingResponseTests.cs diff --git a/sdk/openai/Azure.AI.OpenAI/tests/StreamingResponseTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/StreamingResponseTests.cs new file mode 100644 index 0000000000000..a1d59a9361954 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/tests/StreamingResponseTests.cs @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using NUnit.Framework; + +namespace Azure.AI.OpenAI.Tests; + +[TestFixture] +public class StreamingResponseTests +{ + [Test] + public async Task SingleEnumeration() + { + byte[] responseCharacters = Encoding.UTF8.GetBytes("abcde"); + using Stream responseStream = new MemoryStream(responseCharacters); + + StreamingResponse response = StreamingResponse.CreateFromResponse( + response: new TestResponse(responseStream), + asyncEnumerableProcessor: (r) => EnumerateCharactersFromResponse(r)); + + int index = 0; + await foreach (char characterInResponse in response) + { + Assert.That(index < responseCharacters.Length); + Assert.That((char)responseCharacters[index] == characterInResponse); + index++; + } + Assert.That(index == responseCharacters.Length); + } + + [Test] + public async Task CancellationViaParameter() + { + byte[] responseCharacters = Encoding.UTF8.GetBytes("abcdefg"); + using Stream responseStream = new MemoryStream(responseCharacters); + + var cancelSource = new CancellationTokenSource(); + + StreamingResponse response = StreamingResponse.CreateFromResponse( + response: new TestResponse(responseStream), + asyncEnumerableProcessor: (r) => EnumerateCharactersFromResponse(r, cancelSource.Token)); + + int index = 0; + await foreach (char characterInResponse in response) + { + Assert.That(index < responseCharacters.Length); + Assert.That((char)responseCharacters[index] == characterInResponse); + if (index == 2) + { + cancelSource.Cancel(); + } + index++; + } + Assert.That(index == 3); + } + + [Test] + public async Task CancellationViaWithCancellation() + { + byte[] responseCharacters = Encoding.UTF8.GetBytes("abcdefg"); + using Stream responseStream = new MemoryStream(responseCharacters); + + async static IAsyncEnumerable EnumerateViaWithCancellation( + 
Response response, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + ConfiguredCancelableAsyncEnumerable enumerable + = EnumerateCharactersFromResponse(response).WithCancellation(cancellationToken); + await foreach (var character in enumerable) + { + yield return character; + } + } + + var cancelSource = new CancellationTokenSource(); + + StreamingResponse response = StreamingResponse.CreateFromResponse( + response: new TestResponse(responseStream), + asyncEnumerableProcessor: (r) => EnumerateViaWithCancellation(r, cancelSource.Token)); + + int index = 0; + await foreach (char characterInResponse in response) + { + Assert.That(index < responseCharacters.Length); + Assert.That((char)responseCharacters[index] == characterInResponse); + if (index == 2) + { + cancelSource.Cancel(); + } + index++; + } + Assert.That(index == 3); + } + + [Test] + public async Task CancellationWhenThrowing() + { + byte[] responseCharacters = Encoding.UTF8.GetBytes("abcdefg"); + using Stream responseStream = new MemoryStream(responseCharacters); + + var cancelSource = new CancellationTokenSource(); + + StreamingResponse response = StreamingResponse.CreateFromResponse( + response: new TestResponse(responseStream), + asyncEnumerableProcessor: (r) => EnumerateCharactersFromResponse(r, cancelSource.Token)); + + int index = 0; + await foreach (char characterInResponse in response) + { + Assert.That(index < responseCharacters.Length); + Assert.That((char)responseCharacters[index] == characterInResponse); + if (index == 2) + { + cancelSource.Cancel(); + } + index++; + } + Assert.That(index == 3); + + Exception exception = null; + try + { + cancelSource.Token.ThrowIfCancellationRequested(); + } + catch (Exception ex) + { + exception = ex; + } + + Assert.That(exception, Is.InstanceOf()); + } + + private static async IAsyncEnumerable EnumerateCharactersFromResponse( + Response responseToEnumerate, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if 
(responseToEnumerate?.ContentStream?.CanRead == true) + { + byte[] buffer = new byte[1]; + while ( + !cancellationToken.IsCancellationRequested + && (await responseToEnumerate.ContentStream.ReadAsync(buffer, 0, 1)) == 1) + { + yield return (char)buffer[0]; + } + } + } + + private class TestResponse : Response + { + public override int Status => throw new NotImplementedException(); + + public override string ReasonPhrase => throw new NotImplementedException(); + + public override Stream ContentStream + { + get => _contentStream; + set => throw new NotImplementedException(); + } + + public override string ClientRequestId + { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + private Stream _contentStream { get; } + + public TestResponse(Stream stream) + { + _contentStream = stream; + } + + public override void Dispose() => _contentStream?.Dispose(); + + protected override bool ContainsHeader(string name) => throw new NotImplementedException(); + + protected override IEnumerable EnumerateHeaders() + => throw new NotImplementedException(); + + protected override bool TryGetHeader(string name, out string value) + => throw new NotImplementedException(); + + protected override bool TryGetHeaderValues(string name, out IEnumerable values) + => throw new NotImplementedException(); + } +} From 6df8f6e3c836ffbda38e46dfe0b86ff99220e943 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Fri, 27 Oct 2023 14:33:37 -0700 Subject: [PATCH 16/19] feedback: keep StreamingResponse in Azure.AI.OpenAI --- .../api/Azure.AI.OpenAI.netstandard2.0.cs | 31 +++++++++---------- .../src/Helpers/StreamingResponse.cs | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index 180f775f21a29..884e71f634c5b 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ 
b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -1,16 +1,3 @@ -namespace Azure -{ - public partial class StreamingResponse : System.Collections.Generic.IAsyncEnumerable, System.IDisposable - { - internal StreamingResponse() { } - public static Azure.StreamingResponse CreateFromResponse(Azure.Response response, System.Func> asyncEnumerableProcessor) { throw null; } - public void Dispose() { } - protected virtual void Dispose(bool disposing) { } - public System.Collections.Generic.IAsyncEnumerable EnumerateValues() { throw null; } - public Azure.Response GetRawResponse() { throw null; } - System.Collections.Generic.IAsyncEnumerator System.Collections.Generic.IAsyncEnumerable.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken) { throw null; } - } -} namespace Azure.AI.OpenAI { public partial class AudioTranscription @@ -490,14 +477,14 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent public virtual System.Threading.Tasks.Task> GetAudioTranslationAsync(string deploymentId, Azure.AI.OpenAI.AudioTranslationOptions audioTranslationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetChatCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.StreamingResponse GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = 
default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.AI.OpenAI.StreamingResponse GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.StreamingResponse GetCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions 
completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.AI.OpenAI.StreamingResponse GetCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetEmbeddings(string deploymentOrModelName, Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetEmbeddingsAsync(string deploymentOrModelName, Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetImageGenerations(Azure.AI.OpenAI.ImageGenerationOptions imageGenerationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -536,6 +523,16 @@ internal StreamingChatCompletionsUpdate() { } public string Id { get { throw null; } } public Azure.AI.OpenAI.ChatRole? 
Role { get { throw null; } } } + public partial class StreamingResponse : System.Collections.Generic.IAsyncEnumerable, System.IDisposable + { + internal StreamingResponse() { } + public static Azure.AI.OpenAI.StreamingResponse CreateFromResponse(Azure.Response response, System.Func> asyncEnumerableProcessor) { throw null; } + public void Dispose() { } + protected virtual void Dispose(bool disposing) { } + public System.Collections.Generic.IAsyncEnumerable EnumerateValues() { throw null; } + public Azure.Response GetRawResponse() { throw null; } + System.Collections.Generic.IAsyncEnumerator System.Collections.Generic.IAsyncEnumerable.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken) { throw null; } + } } namespace Microsoft.Extensions.Azure { diff --git a/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs b/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs index f0122f425d4d1..5ba9fcb5aff47 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; using System.Threading; -namespace Azure; +namespace Azure.AI.OpenAI; /// /// Represents an operation response with streaming content that can be deserialized and enumerated while the response From 041197844f01b0fe92fdf737f557199de2768259 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Fri, 27 Oct 2023 17:17:42 -0700 Subject: [PATCH 17/19] address missing 'using' on JsonDocument --- sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs b/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs index 9c52fae14585d..9c32d20ee4e6f 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs @@ -35,7 +35,7 @@ internal static async 
IAsyncEnumerable<T> EnumerateFromSseStream( { break; } - JsonDocument sseMessageJson = JsonDocument.Parse(value); + using JsonDocument sseMessageJson = JsonDocument.Parse(value); IEnumerable<T> newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); foreach (T item in newItems) { From b8472ccbe27ca95be6191698c48ceea14df504a4 Mon Sep 17 00:00:00 2001 From: Travis Wilson <35748617+trrwilson@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:26:22 -0700 Subject: [PATCH 18/19] tie up broken link from changelog --- sdk/openai/Azure.AI.OpenAI/CHANGELOG.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md index 8df3b19f05626..2900f27c59295 100644 --- a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md +++ b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md @@ -18,8 +18,9 @@ - Correspondingly, `OpenAI.GetChatCompletionsStreaming()` now returns a `StreamingResponse<StreamingChatCompletionsUpdate>` that may be enumerated over directly. `StreamingChatCompletions`, `StreamingChatChoice`, and related methods are removed. - - For more information, please see [the related pull request description]() as well as the updated snippets in the - project README. + - For more information, please see + [the related pull request description](https://github.com/Azure/azure-sdk-for-net/pull/39347) as well as the + updated snippets in the project README. 
### Bugs Fixed From 05b8a4ca4fc201531d54c72a3f19a294eeaaf99a Mon Sep 17 00:00:00 2001 From: Travis Wilson <35748617+trrwilson@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:41:40 -0700 Subject: [PATCH 19/19] Post-merge: export-api and update-snippets --- sdk/openai/Azure.AI.OpenAI/README.md | 12 ++++-------- .../api/Azure.AI.OpenAI.netstandard2.0.cs | 8 ++++---- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md index 52dc7380b95e3..646a327e73513 100644 --- a/sdk/openai/Azure.AI.OpenAI/README.md +++ b/sdk/openai/Azure.AI.OpenAI/README.md @@ -234,9 +234,8 @@ var chatCompletionsOptions = new ChatCompletionsOptions() ChoiceCount = 4 }; -await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatCompletionsStreaming( - deploymentOrModelName: "gpt-3.5-turbo", - chatCompletionsOptions)) +await foreach (StreamingChatCompletionsUpdate chatUpdate + in client.GetChatCompletionsStreaming(chatCompletionsOptions)) { // Choice-specific information like Role and ContentUpdate will also provide a ChoiceIndex that allows // StreamingChatCompletionsUpdate data for independent choices to be appropriately separated. 
@@ -381,11 +380,8 @@ StringBuilder functionArgumentsBuilder = new(); ChatRole streamedRole = default; CompletionsFinishReason finishReason = default; -using StreamingResponse streamingResponse - = client.GetChatCompletionsStreaming( - "gpt-35-turbo-0613", - chatCompletionsOptions); -await foreach (StreamingChatCompletionsUpdate update in streamingResponse) +await foreach (StreamingChatCompletionsUpdate update + in client.GetChatCompletionsStreaming(chatCompletionsOptions)) { contentBuilder.Append(update.ContentUpdate); functionName ??= update.FunctionName; diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index 2f178949dd324..6e9bb46de543a 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -481,12 +481,12 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent public virtual System.Threading.Tasks.Task> GetAudioTranslationAsync(Azure.AI.OpenAI.AudioTranslationOptions audioTranslationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetChatCompletions(Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetChatCompletionsAsync(Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.Response GetChatCompletionsStreaming(Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual 
System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.AI.OpenAI.StreamingResponse GetChatCompletionsStreaming(Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetCompletions(Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetCompletionsAsync(Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual Azure.Response GetCompletionsStreaming(Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual Azure.AI.OpenAI.StreamingResponse GetCompletionsStreaming(Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public virtual System.Threading.Tasks.Task> 
GetCompletionsStreamingAsync(Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetEmbeddings(Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual System.Threading.Tasks.Task> GetEmbeddingsAsync(Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public virtual Azure.Response GetImageGenerations(Azure.AI.OpenAI.ImageGenerationOptions imageGenerationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }