From d591c819b0e090c3a8b2e5394b1c7175f32fe4b2 Mon Sep 17 00:00:00 2001
From: Travis Wilson <35748617+trrwilson@users.noreply.github.com>
Date: Tue, 17 Oct 2023 18:53:15 -0700
Subject: [PATCH 01/19] DRAFT: unvetted proposal for flattened streaming
---
.../AzureChatExtensionsMessageContext.cs | 33 ++++
.../src/Custom/OpenAIClient.cs | 100 +++++++++++
.../src/Custom/StreamingChatCompletions2.cs | 78 +++++++++
...mingChatCompletionsUpdate.Serialization.cs | 164 ++++++++++++++++++
.../Custom/StreamingChatCompletionsUpdate.cs | 59 +++++++
.../tests/OpenAIInferenceTests.cs | 73 ++++++++
6 files changed, 507 insertions(+)
create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs
create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs
create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs
new file mode 100644
index 0000000000000..1e903ab1326ba
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureChatExtensionsMessageContext.cs
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+//
+
+#nullable disable
+
+using System.Collections.Generic;
+using Azure.Core;
+
+namespace Azure.AI.OpenAI
+{
+ ///
+ /// A representation of the additional context information available when Azure OpenAI chat extensions are involved
+ /// in the generation of a corresponding chat completions response. This context information is only populated when
+ /// using an Azure OpenAI request configured to use a matching extension.
+ ///
+ public partial class AzureChatExtensionsMessageContext
+ {
+ public ContentFilterResults RequestContentFilterResults { get; internal set; }
+ public ContentFilterResults ResponseContentFilterResults { get; internal set; }
+
+ internal AzureChatExtensionsMessageContext(
+ IList messages,
+ ContentFilterResults requestContentFilterResults,
+ ContentFilterResults responseContentFilterResults)
+ : this(messages)
+ {
+ RequestContentFilterResults = requestContentFilterResults;
+ ResponseContentFilterResults = responseContentFilterResults;
+ }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
index e46691d42ae88..d55ad4462a585 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
@@ -500,6 +500,106 @@ public virtual async Task> GetChatCompletions
}
}
+ ///
+ /// Begin a chat completions request and get an object that can stream response data as it becomes
+ /// available.
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// the chat completions options for this chat completions request.
+ ///
+ ///
+ /// a cancellation token that can be used to cancel the initial request or ongoing streaming operation.
+ ///
+ ///
+ /// or is null.
+ ///
+ /// Service returned a non-success status code.
+ /// The response returned from the service.
+ public virtual Response GetChatCompletionsStreaming2(
+ string deploymentOrModelName,
+ ChatCompletionsOptions chatCompletionsOptions,
+ CancellationToken cancellationToken = default)
+ {
+ Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
+ Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
+
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2");
+ scope.Start();
+
+ chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
+ chatCompletionsOptions.InternalShouldStreamResponse = true;
+
+ string operationPath = GetOperationPath(chatCompletionsOptions);
+
+ RequestContent content = chatCompletionsOptions.ToRequestContent();
+ RequestContext context = FromCancellationToken(cancellationToken);
+
+ try
+ {
+ // Response value object takes IDisposable ownership of message
+ HttpMessage message = CreatePostRequestMessage(
+ deploymentOrModelName,
+ operationPath,
+ content,
+ context);
+ message.BufferResponse = false;
+ Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
+ return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse);
+ }
+ catch (Exception e)
+ {
+ scope.Failed(e);
+ throw;
+ }
+ }
+
+ ///
+ public virtual async Task> GetChatCompletionsStreamingAsync2(
+ string deploymentOrModelName,
+ ChatCompletionsOptions chatCompletionsOptions,
+ CancellationToken cancellationToken = default)
+ {
+ Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
+ Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
+
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2");
+ scope.Start();
+
+ chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
+ chatCompletionsOptions.InternalShouldStreamResponse = true;
+
+ string operationPath = GetOperationPath(chatCompletionsOptions);
+
+ RequestContent content = chatCompletionsOptions.ToRequestContent();
+ RequestContext context = FromCancellationToken(cancellationToken);
+
+ try
+ {
+ // Response value object takes IDisposable ownership of message
+ HttpMessage message = CreatePostRequestMessage(
+ deploymentOrModelName,
+ operationPath,
+ content,
+ context);
+ message.BufferResponse = false;
+ Response baseResponse = await _pipeline.ProcessMessageAsync(
+ message,
+ context,
+ cancellationToken).ConfigureAwait(false);
+ return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse);
+ }
+ catch (Exception e)
+ {
+ scope.Failed(e);
+ throw;
+ }
+ }
+
/// Return the computed embeddings for a given prompt.
///
/// EnumerateChatUpdates(
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ while (true)
+ {
+ SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
+ if (sseEvent == null)
+ {
+ _baseResponse.ContentStream?.Dispose();
+ break;
+ }
+
+ ReadOnlyMemory name = sseEvent.Value.FieldName;
+ if (!name.Span.SequenceEqual("data".AsSpan()))
+ throw new InvalidDataException();
+
+ ReadOnlyMemory value = sseEvent.Value.FieldValue;
+ if (value.Span.SequenceEqual("[DONE]".AsSpan()))
+ {
+ _baseResponse.ContentStream?.Dispose();
+ break;
+ }
+
+ JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
+ foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate
+ in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement))
+ {
+ yield return streamingChatCompletionsUpdate;
+ }
+ }
+ }
+
+ public void Dispose()
+ {
+ Dispose(disposing: true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (!_disposedValue)
+ {
+ if (disposing)
+ {
+ _baseResponseReader.Dispose();
+ }
+
+ _disposedValue = true;
+ }
+ }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs
new file mode 100644
index 0000000000000..4d11ca2a15229
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.Serialization.cs
@@ -0,0 +1,164 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading;
+using Azure.Core;
+
+namespace Azure.AI.OpenAI
+{
+ public partial class StreamingChatCompletionsUpdate
+ {
+ internal static List DeserializeStreamingChatCompletionsUpdates(JsonElement element)
+ {
+ var results = new List();
+ if (element.ValueKind == JsonValueKind.Null)
+ {
+ return results;
+ }
+ string id = default;
+ DateTimeOffset created = default;
+ AzureChatExtensionsMessageContext azureExtensionsContext = null;
+ ContentFilterResults requestContentFilterResults = null;
+ foreach (JsonProperty property in element.EnumerateObject())
+ {
+ if (property.NameEquals("id"u8))
+ {
+ id = property.Value.GetString();
+ continue;
+ }
+ if (property.NameEquals("created"u8))
+ {
+ created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());
+ continue;
+ }
+ // CUSTOM CODE NOTE: temporary, custom handling of forked keys for prompt filter results
+ if (property.NameEquals("prompt_annotations"u8) || property.NameEquals("prompt_filter_results"))
+ {
+ if (property.Value.ValueKind == JsonValueKind.Null)
+ {
+ continue;
+ }
+ List array = new List();
+ foreach (var item in property.Value.EnumerateArray())
+ {
+ PromptFilterResult promptFilterResult = PromptFilterResult.DeserializePromptFilterResult(item);
+ requestContentFilterResults = promptFilterResult.ContentFilterResults;
+ }
+ continue;
+ }
+ if (property.NameEquals("choices"u8))
+ {
+ foreach (JsonElement choiceElement in property.Value.EnumerateArray())
+ {
+ ChatRole? role = null;
+ string contentUpdate = null;
+ string functionName = null;
+ string authorName = null;
+ string functionArgumentsUpdate = null;
+ int choiceIndex = 0;
+ CompletionsFinishReason? finishReason = null;
+ ContentFilterResults responseContentFilterResults = null;
+ foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject())
+ {
+ if (choiceProperty.NameEquals("index"u8))
+ {
+ choiceIndex = choiceProperty.Value.GetInt32();
+ continue;
+ }
+ if (choiceProperty.NameEquals("finish_reason"u8))
+ {
+ finishReason = new CompletionsFinishReason(choiceProperty.Value.GetString());
+ continue;
+ }
+ if (choiceProperty.NameEquals("delta"u8))
+ {
+ foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject())
+ {
+ if (deltaProperty.NameEquals("role"u8))
+ {
+ role = deltaProperty.Value.GetString();
+ continue;
+ }
+ if (deltaProperty.NameEquals("name"u8))
+ {
+ authorName = deltaProperty.Value.GetString();
+ continue;
+ }
+ if (deltaProperty.NameEquals("content"u8))
+ {
+ contentUpdate = deltaProperty.Value.GetString();
+ continue;
+ }
+ if (deltaProperty.NameEquals("function_call"u8))
+ {
+ foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject())
+ {
+ if (functionProperty.NameEquals("name"u8))
+ {
+ functionName = functionProperty.Value.GetString();
+ continue;
+ }
+ if (functionProperty.NameEquals("arguments"u8))
+ {
+ functionArgumentsUpdate = functionProperty.Value.GetString();
+ }
+ }
+ }
+ if (deltaProperty.NameEquals("context"u8))
+ {
+ azureExtensionsContext = AzureChatExtensionsMessageContext
+ .DeserializeAzureChatExtensionsMessageContext(deltaProperty.Value);
+ continue;
+ }
+ }
+ }
+ if (choiceProperty.NameEquals("content_filter_results"u8))
+ {
+ if (choiceProperty.Value.EnumerateObject().Any())
+ {
+ responseContentFilterResults = ContentFilterResults.DeserializeContentFilterResults(choiceProperty.Value);
+ }
+ continue;
+ }
+ }
+ if (requestContentFilterResults is not null || responseContentFilterResults is not null)
+ {
+ azureExtensionsContext ??= new AzureChatExtensionsMessageContext();
+ azureExtensionsContext.RequestContentFilterResults = requestContentFilterResults;
+ azureExtensionsContext.ResponseContentFilterResults = responseContentFilterResults;
+ }
+ results.Add(new StreamingChatCompletionsUpdate(
+ id,
+ created,
+ choiceIndex,
+ role,
+ authorName,
+ contentUpdate,
+ finishReason,
+ functionName,
+ functionArgumentsUpdate,
+ azureExtensionsContext));
+ }
+ continue;
+ }
+ }
+ if (results.Count == 0)
+ {
+ if (requestContentFilterResults is not null)
+ {
+ azureExtensionsContext ??= new AzureChatExtensionsMessageContext()
+ {
+ RequestContentFilterResults = requestContentFilterResults,
+ };
+ }
+ results.Add(new StreamingChatCompletionsUpdate(id, created, azureExtensionsContext: azureExtensionsContext));
+ }
+ return results;
+ }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
new file mode 100644
index 0000000000000..c16e0d20f2e5d
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using Azure.Core;
+
+namespace Azure.AI.OpenAI
+{
+ public partial class StreamingChatCompletionsUpdate
+ {
+ public string Id { get; }
+
+ public DateTimeOffset Created { get; }
+
+ public ChatRole? Role { get; }
+
+ public string ContentUpdate { get; }
+
+ public string FunctionName { get; }
+
+ public string FunctionArgumentsUpdate { get; }
+
+ public string AuthorName { get; }
+
+ public CompletionsFinishReason? FinishReason { get; }
+
+ public int? ChoiceIndex { get; }
+
+ public AzureChatExtensionsMessageContext AzureExtensionsContext { get; }
+
+ internal StreamingChatCompletionsUpdate(
+ string id,
+ DateTimeOffset created,
+ int? choiceIndex = null,
+ ChatRole? role = null,
+ string authorName = null,
+ string contentUpdate = null,
+ CompletionsFinishReason? finishReason = null,
+ string functionName = null,
+ string functionArgumentsUpdate = null,
+ AzureChatExtensionsMessageContext azureExtensionsContext = null)
+ {
+ Id = id;
+ Created = created;
+ ChoiceIndex = choiceIndex;
+ Role = role;
+ AuthorName = authorName;
+ ContentUpdate = contentUpdate;
+ FinishReason = finishReason;
+ FunctionName = functionName;
+ FunctionArgumentsUpdate = functionArgumentsUpdate;
+ AzureExtensionsContext = azureExtensionsContext;
+ }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
index f7eaaf696bc35..7376674572ee8 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
@@ -273,6 +273,79 @@ Response streamingResponse
AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1));
}
+ [RecordedTest]
+ [TestCase(OpenAIClientServiceTarget.Azure)]
+ [TestCase(OpenAIClientServiceTarget.NonAzure)]
+ public async Task StreamingChatCompletions2(OpenAIClientServiceTarget serviceTarget)
+ {
+ OpenAIClient client = GetTestClient(serviceTarget);
+ string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
+ serviceTarget,
+ OpenAIClientScenario.ChatCompletions);
+ var requestOptions = new ChatCompletionsOptions()
+ {
+ Messages =
+ {
+ new ChatMessage(ChatRole.System, "You are a helpful assistant."),
+ new ChatMessage(ChatRole.User, "Can you help me?"),
+ new ChatMessage(ChatRole.Assistant, "Of course! What do you need help with?"),
+ new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"),
+ },
+ MaxTokens = 512,
+ };
+ Response streamingResponse
+ = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions);
+ Assert.That(streamingResponse, Is.Not.Null);
+ using StreamingChatCompletions2 streamingChatCompletions = streamingResponse.Value;
+ Assert.That(streamingChatCompletions, Is.InstanceOf());
+
+ StringBuilder contentBuilder = new StringBuilder();
+ bool gotRole = false;
+ bool gotRequestContentFilterResults = false;
+ bool gotResponseContentFilterResults = false;
+
+ await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates())
+ {
+ Assert.That(chatUpdate, Is.Not.Null);
+
+ if (chatUpdate.AzureExtensionsContext?.RequestContentFilterResults is null)
+ {
+ Assert.That(chatUpdate.Id, Is.Not.Null.Or.Empty);
+ Assert.That(chatUpdate.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1))));
+ Assert.That(chatUpdate.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7)));
+ }
+ if (chatUpdate.Role.HasValue)
+ {
+ Assert.IsFalse(gotRole);
+ Assert.That(chatUpdate.Role.Value, Is.EqualTo(ChatRole.Assistant));
+ gotRole = true;
+ }
+ if (chatUpdate.ContentUpdate is not null)
+ {
+ contentBuilder.Append(chatUpdate.ContentUpdate);
+ }
+ if (chatUpdate.AzureExtensionsContext?.RequestContentFilterResults is not null)
+ {
+ Assert.IsFalse(gotRequestContentFilterResults);
+ AssertExpectedContentFilterResults(chatUpdate.AzureExtensionsContext.RequestContentFilterResults, serviceTarget);
+ gotRequestContentFilterResults = true;
+ }
+ if (chatUpdate.AzureExtensionsContext?.ResponseContentFilterResults is not null)
+ {
+ AssertExpectedContentFilterResults(chatUpdate.AzureExtensionsContext.ResponseContentFilterResults, serviceTarget);
+ gotResponseContentFilterResults = true;
+ }
+ }
+
+ Assert.IsTrue(gotRole);
+ Assert.That(contentBuilder.ToString(), Is.Not.Null.Or.Empty);
+ if (serviceTarget == OpenAIClientServiceTarget.Azure)
+ {
+ Assert.IsTrue(gotRequestContentFilterResults);
+ Assert.IsTrue(gotResponseContentFilterResults);
+ }
+ }
+
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
[TestCase(OpenAIClientServiceTarget.NonAzure)]
From 6dfd14a8b19960021f448193643d7f7fca1e12b7 Mon Sep 17 00:00:00 2001
From: Travis Wilson <35748617+trrwilson@users.noreply.github.com>
Date: Wed, 18 Oct 2023 09:45:36 -0700
Subject: [PATCH 02/19] add ported functions test
---
.../api/Azure.AI.OpenAI.netstandard2.0.cs | 25 +++++++
.../src/Custom/StreamingChatCompletions2.cs | 9 ++-
.../tests/ChatFunctionsTests.cs | 68 +++++++++++++++++++
3 files changed, 101 insertions(+), 1 deletion(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
index a4b3f10dcf7df..8fd561d1c7c8d 100644
--- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
+++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
@@ -114,6 +114,8 @@ public partial class AzureChatExtensionsMessageContext
{
public AzureChatExtensionsMessageContext() { }
public System.Collections.Generic.IList Messages { get { throw null; } }
+ public Azure.AI.OpenAI.ContentFilterResults RequestContentFilterResults { get { throw null; } }
+ public Azure.AI.OpenAI.ContentFilterResults ResponseContentFilterResults { get { throw null; } }
}
public partial class AzureChatExtensionsOptions
{
@@ -479,7 +481,9 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent
public virtual Azure.Response GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetChatCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public virtual Azure.Response GetChatCompletionsStreaming2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
@@ -528,6 +532,27 @@ public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
public System.Collections.Generic.IAsyncEnumerable GetChoicesStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
}
+ public partial class StreamingChatCompletions2 : System.IDisposable
+ {
+ internal StreamingChatCompletions2() { }
+ public void Dispose() { }
+ protected virtual void Dispose(bool disposing) { }
+ public System.Collections.Generic.IAsyncEnumerable EnumerateChatUpdates([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ }
+ public partial class StreamingChatCompletionsUpdate
+ {
+ internal StreamingChatCompletionsUpdate() { }
+ public string AuthorName { get { throw null; } }
+ public Azure.AI.OpenAI.AzureChatExtensionsMessageContext AzureExtensionsContext { get { throw null; } }
+ public int? ChoiceIndex { get { throw null; } }
+ public string ContentUpdate { get { throw null; } }
+ public System.DateTimeOffset Created { get { throw null; } }
+ public Azure.AI.OpenAI.CompletionsFinishReason? FinishReason { get { throw null; } }
+ public string FunctionArgumentsUpdate { get { throw null; } }
+ public string FunctionName { get { throw null; } }
+ public string Id { get { throw null; } }
+ public Azure.AI.OpenAI.ChatRole? Role { get { throw null; } }
+ }
public partial class StreamingChoice
{
internal StreamingChoice() { }
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
index fb16e591e6766..967dea227f2a7 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
@@ -27,7 +27,13 @@ internal StreamingChatCompletions2(Response response)
public async IAsyncEnumerable EnumerateChatUpdates(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- while (true)
+ // Open clarification points:
+ // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested?
+ // - What should happen if the stream is held too long and it times out?
+ // - Should enumeration be repeatable, e.g. via locally caching a collection of the updates?
+ // - What if the initial enumeration was cancelled?
+ // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently?
+ while (!cancellationToken.IsCancellationRequested)
{
SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
if (sseEvent == null)
@@ -54,6 +60,7 @@ in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sse
yield return streamingChatCompletionsUpdate;
}
}
+ _baseResponse?.ContentStream?.Dispose();
}
public void Dispose()
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
index 36b612ce299dd..57be0f978564a 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
@@ -135,6 +135,55 @@ Response response
Assert.That(argumentsBuilder.Length, Is.GreaterThan(0));
}
+ [RecordedTest]
+ [TestCase(OpenAIClientServiceTarget.Azure)]
+ [TestCase(OpenAIClientServiceTarget.NonAzure)]
+ public async Task StreamingFunctionCallWorks2(OpenAIClientServiceTarget serviceTarget)
+ {
+ OpenAIClient client = GetTestClient(serviceTarget);
+ string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions);
+
+ var requestOptions = new ChatCompletionsOptions()
+ {
+ Functions = { s_futureTemperatureFunction, s_wordOfTheDayFunction },
+ Messages =
+ {
+ new ChatMessage(ChatRole.System, "You are a helpful assistant."),
+ new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"),
+ },
+ MaxTokens = 512,
+ };
+
+ Response response
+ = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions);
+ Assert.That(response, Is.Not.Null);
+
+ using StreamingChatCompletions2 streamingChatCompletions = response.Value;
+
+ ChatRole? streamedRole = default;
+ string functionName = null;
+ StringBuilder argumentsBuilder = new();
+
+ await foreach (StreamingChatCompletionsUpdate chatUpdate in streamingChatCompletions.EnumerateChatUpdates())
+ {
+ if (chatUpdate.Role.HasValue)
+ {
+ Assert.That(streamedRole, Is.Null, "role should only appear once");
+ streamedRole = chatUpdate.Role.Value;
+ }
+ if (!string.IsNullOrEmpty(chatUpdate.FunctionName))
+ {
+ Assert.That(functionName, Is.Null, "function_name should only appear once");
+ functionName = chatUpdate.FunctionName;
+ }
+ argumentsBuilder.Append(chatUpdate.FunctionArgumentsUpdate);
+ }
+
+ Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant));
+ Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name));
+ Assert.That(argumentsBuilder.Length, Is.GreaterThan(0));
+ }
+
private static readonly FunctionDefinition s_futureTemperatureFunction = new()
{
Name = "get_future_temperature",
@@ -159,4 +208,23 @@ Response response
},
new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
};
+
+ private static readonly FunctionDefinition s_wordOfTheDayFunction = new()
+ {
+ Name = "get_word_of_the_day",
+ Description = "requests a featured word for a given day of the week",
+ Parameters = BinaryData.FromObjectAsJson(new
+ {
+ Type = "object",
+ Properties = new
+ {
+ DayOfWeek = new
+ {
+ Type = "string",
+ Description = "the name of a day of the week, like Wednesday",
+ },
+ }
+ },
+ new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase })
+ };
}
From 4417b07dc6a19f238d794c63db85e05cfb19d7d5 Mon Sep 17 00:00:00 2001
From: Travis Wilson <35748617+trrwilson@users.noreply.github.com>
Date: Wed, 18 Oct 2023 10:17:48 -0700
Subject: [PATCH 03/19] remaining tests ported
---
.../src/Custom/OpenAIClient.cs | 198 +++++++++---------
.../tests/AzureChatExtensionsTests.cs | 96 +++++++--
.../tests/ChatFunctionsTests.cs | 100 ++++-----
.../tests/OpenAIInferenceTests.cs | 98 ++++-----
.../tests/Samples/Sample04_StreamingChat.cs | 49 ++++-
5 files changed, 324 insertions(+), 217 deletions(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
index d55ad4462a585..6c44d50e6cc19 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
@@ -400,105 +400,105 @@ public virtual async Task> GetChatCompletionsAsync(
}
}
- ///
- /// Begin a chat completions request and get an object that can stream response data as it becomes
- /// available.
- ///
- ///
- ///
- ///
- ///
- /// the chat completions options for this chat completions request.
- ///
- ///
- /// a cancellation token that can be used to cancel the initial request or ongoing streaming operation.
- ///
- ///
- /// or is null.
- ///
- /// Service returned a non-success status code.
- /// The response returned from the service.
- public virtual Response GetChatCompletionsStreaming(
- string deploymentOrModelName,
- ChatCompletionsOptions chatCompletionsOptions,
- CancellationToken cancellationToken = default)
- {
- Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
- Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
-
- using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
- scope.Start();
-
- chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
- chatCompletionsOptions.InternalShouldStreamResponse = true;
-
- string operationPath = GetOperationPath(chatCompletionsOptions);
-
- RequestContent content = chatCompletionsOptions.ToRequestContent();
- RequestContext context = FromCancellationToken(cancellationToken);
-
- try
- {
- // Response value object takes IDisposable ownership of message
- HttpMessage message = CreatePostRequestMessage(
- deploymentOrModelName,
- operationPath,
- content,
- context);
- message.BufferResponse = false;
- Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
- return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
- }
- catch (Exception e)
- {
- scope.Failed(e);
- throw;
- }
- }
-
- ///
- public virtual async Task> GetChatCompletionsStreamingAsync(
- string deploymentOrModelName,
- ChatCompletionsOptions chatCompletionsOptions,
- CancellationToken cancellationToken = default)
- {
- Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
- Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
-
- using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
- scope.Start();
-
- chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
- chatCompletionsOptions.InternalShouldStreamResponse = true;
-
- string operationPath = GetOperationPath(chatCompletionsOptions);
-
- RequestContent content = chatCompletionsOptions.ToRequestContent();
- RequestContext context = FromCancellationToken(cancellationToken);
-
- try
- {
- // Response value object takes IDisposable ownership of message
- HttpMessage message = CreatePostRequestMessage(
- deploymentOrModelName,
- operationPath,
- content,
- context);
- message.BufferResponse = false;
- Response baseResponse = await _pipeline.ProcessMessageAsync(
- message,
- context,
- cancellationToken).ConfigureAwait(false);
- return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
- }
- catch (Exception e)
- {
- scope.Failed(e);
- throw;
- }
- }
+ /////
+ ///// Begin a chat completions request and get an object that can stream response data as it becomes
+ ///// available.
+ /////
+ /////
+ /////
+ /////
+ /////
+ ///// the chat completions options for this chat completions request.
+ /////
+ /////
+ ///// a cancellation token that can be used to cancel the initial request or ongoing streaming operation.
+ /////
+ /////
+ ///// or is null.
+ /////
+ ///// Service returned a non-success status code.
+ ///// The response returned from the service.
+ //public virtual Response GetChatCompletionsStreaming(
+ // string deploymentOrModelName,
+ // ChatCompletionsOptions chatCompletionsOptions,
+ // CancellationToken cancellationToken = default)
+ //{
+ // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
+ // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
+
+ // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
+ // scope.Start();
+
+ // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
+ // chatCompletionsOptions.InternalShouldStreamResponse = true;
+
+ // string operationPath = GetOperationPath(chatCompletionsOptions);
+
+ // RequestContent content = chatCompletionsOptions.ToRequestContent();
+ // RequestContext context = FromCancellationToken(cancellationToken);
+
+ // try
+ // {
+ // // Response value object takes IDisposable ownership of message
+ // HttpMessage message = CreatePostRequestMessage(
+ // deploymentOrModelName,
+ // operationPath,
+ // content,
+ // context);
+ // message.BufferResponse = false;
+ // Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
+ // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
+ // }
+ // catch (Exception e)
+ // {
+ // scope.Failed(e);
+ // throw;
+ // }
+ //}
+
+ /////
+ //public virtual async Task> GetChatCompletionsStreamingAsync(
+ // string deploymentOrModelName,
+ // ChatCompletionsOptions chatCompletionsOptions,
+ // CancellationToken cancellationToken = default)
+ //{
+ // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
+ // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
+
+ // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
+ // scope.Start();
+
+ // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
+ // chatCompletionsOptions.InternalShouldStreamResponse = true;
+
+ // string operationPath = GetOperationPath(chatCompletionsOptions);
+
+ // RequestContent content = chatCompletionsOptions.ToRequestContent();
+ // RequestContext context = FromCancellationToken(cancellationToken);
+
+ // try
+ // {
+ // // Response value object takes IDisposable ownership of message
+ // HttpMessage message = CreatePostRequestMessage(
+ // deploymentOrModelName,
+ // operationPath,
+ // content,
+ // context);
+ // message.BufferResponse = false;
+ // Response baseResponse = await _pipeline.ProcessMessageAsync(
+ // message,
+ // context,
+ // cancellationToken).ConfigureAwait(false);
+ // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
+ // }
+ // catch (Exception e)
+ // {
+ // scope.Failed(e);
+ // throw;
+ // }
+ //}
///
/// Begin a chat completions request and get an object that can stream response data as it becomes
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
index 7f56bb370064c..8f913da5a4dfd 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core.TestFramework;
@@ -88,6 +89,68 @@ public async Task BasicSearchExtensionWorks(
Assert.That(context.Messages.First().Content.Contains("citations"));
}
+ //[RecordedTest]
+ //[TestCase(OpenAIClientServiceTarget.Azure)]
+ //public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget)
+ //{
+ // OpenAIClient client = GetTestClient(serviceTarget);
+ // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
+ // serviceTarget,
+ // OpenAIClientScenario.ChatCompletions);
+
+ // var requestOptions = new ChatCompletionsOptions()
+ // {
+ // Messages =
+ // {
+ // new ChatMessage(ChatRole.User, "What does PR complete mean?"),
+ // },
+ // MaxTokens = 512,
+ // AzureExtensionsOptions = new()
+ // {
+ // Extensions =
+ // {
+ // new AzureChatExtensionConfiguration()
+ // {
+ // Type = "AzureCognitiveSearch",
+ // Parameters = BinaryData.FromObjectAsJson(new
+ // {
+ // Endpoint = "https://openaisdktestsearch.search.windows.net",
+ // IndexName = "openai-test-index-carbon-wiki",
+ // Key = GetCognitiveSearchApiKey().Key,
+ // },
+ // new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }),
+ // },
+ // }
+ // },
+ // };
+
+ // Response<StreamingChatCompletions> response = await client.GetChatCompletionsStreamingAsync(
+ // deploymentOrModelName,
+ // requestOptions);
+ // Assert.That(response, Is.Not.Null);
+ // Assert.That(response.Value, Is.Not.Null);
+
+ // using StreamingChatCompletions streamingChatCompletions = response.Value;
+
+ // int choiceCount = 0;
+ // List<ChatMessage> messageChunks = new();
+
+ // await foreach (StreamingChatChoice streamingChatChoice in response.Value.GetChoicesStreaming())
+ // {
+ // choiceCount++;
+ // await foreach (ChatMessage chatMessage in streamingChatChoice.GetMessageStreaming())
+ // {
+ // messageChunks.Add(chatMessage);
+ // }
+ // }
+
+ // Assert.That(choiceCount, Is.EqualTo(1));
+ // Assert.That(messageChunks, Is.Not.Null.Or.Empty);
+ // Assert.That(messageChunks.Any(chunk => chunk.AzureExtensionsContext != null && chunk.AzureExtensionsContext.Messages.Any(m => m.Role == ChatRole.Tool)));
+ // //Assert.That(messageChunks.Any(chunk => chunk.Role == ChatRole.Assistant));
+ // Assert.That(messageChunks.Any(chunk => !string.IsNullOrWhiteSpace(chunk.Content)));
+ //}
+
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget)
@@ -123,30 +186,37 @@ public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget servic
},
};
- Response<StreamingChatCompletions> response = await client.GetChatCompletionsStreamingAsync(
+ Response<StreamingChatCompletions2> response = await client.GetChatCompletionsStreamingAsync2(
deploymentOrModelName,
requestOptions);
Assert.That(response, Is.Not.Null);
Assert.That(response.Value, Is.Not.Null);
- using StreamingChatCompletions streamingChatCompletions = response.Value;
+ using StreamingChatCompletions2 streamingChatCompletions = response.Value;
- int choiceCount = 0;
- List<ChatMessage> messageChunks = new();
+ ChatRole? streamedRole = null;
+ IEnumerable<ChatMessage> azureContextMessages = null;
+ StringBuilder contentBuilder = new();
- await foreach (StreamingChatChoice streamingChatChoice in response.Value.GetChoicesStreaming())
+ await foreach (StreamingChatCompletionsUpdate chatUpdate in response.Value.EnumerateChatUpdates())
{
- choiceCount++;
- await foreach (ChatMessage chatMessage in streamingChatChoice.GetMessageStreaming())
+ if (chatUpdate.Role.HasValue)
+ {
+ Assert.That(streamedRole, Is.Null);
+ streamedRole = chatUpdate.Role.Value;
+ }
+ if (chatUpdate.AzureExtensionsContext?.Messages?.Count > 0)
{
- messageChunks.Add(chatMessage);
+ Assert.That(azureContextMessages, Is.Null);
+ azureContextMessages = chatUpdate.AzureExtensionsContext.Messages;
}
+ contentBuilder.Append(chatUpdate.ContentUpdate);
}
- Assert.That(choiceCount, Is.EqualTo(1));
- Assert.That(messageChunks, Is.Not.Null.Or.Empty);
- Assert.That(messageChunks.Any(chunk => chunk.AzureExtensionsContext != null && chunk.AzureExtensionsContext.Messages.Any(m => m.Role == ChatRole.Tool)));
- //Assert.That(messageChunks.Any(chunk => chunk.Role == ChatRole.Assistant));
- Assert.That(messageChunks.Any(chunk => !string.IsNullOrWhiteSpace(chunk.Content)));
+ Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant));
+ Assert.That(contentBuilder.ToString(), Is.Not.Null.Or.Empty);
+ Assert.That(azureContextMessages, Is.Not.Null.Or.Empty);
+ Assert.That(azureContextMessages.Any(contextMessage => contextMessage.Role == ChatRole.Tool));
+ Assert.That(azureContextMessages.Any(contextMessage => !string.IsNullOrEmpty(contextMessage.Content)));
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
index 57be0f978564a..c7d3289706775 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
@@ -84,56 +84,56 @@ public async Task SimpleFunctionCallWorks(OpenAIClientServiceTarget serviceTarge
Assert.That(followupResponse.Value.Choices[0].Message.Content, Is.Not.Null.Or.Empty);
}
- [RecordedTest]
- [TestCase(OpenAIClientServiceTarget.Azure)]
- [TestCase(OpenAIClientServiceTarget.NonAzure)]
- public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget)
- {
- OpenAIClient client = GetTestClient(serviceTarget);
- string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions);
-
- var requestOptions = new ChatCompletionsOptions()
- {
- Functions = { s_futureTemperatureFunction },
- Messages =
- {
- new ChatMessage(ChatRole.System, "You are a helpful assistant."),
- new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"),
- },
- MaxTokens = 512,
- };
-
- Response<StreamingChatCompletions> response
- = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
- Assert.That(response, Is.Not.Null);
-
- using StreamingChatCompletions streamingChatCompletions = response.Value;
-
- ChatRole streamedRole = default;
- string functionName = default;
- StringBuilder argumentsBuilder = new();
-
- await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
- {
- await foreach (ChatMessage message in choice.GetMessageStreaming())
- {
- if (message.Role != default)
- {
- streamedRole = message.Role;
- }
- if (message.FunctionCall?.Name != null)
- {
- Assert.That(functionName, Is.Null.Or.Empty);
- functionName = message.FunctionCall.Name;
- }
- argumentsBuilder.Append(message.FunctionCall?.Arguments ?? string.Empty);
- }
- }
-
- Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant));
- Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name));
- Assert.That(argumentsBuilder.Length, Is.GreaterThan(0));
- }
+ //[RecordedTest]
+ //[TestCase(OpenAIClientServiceTarget.Azure)]
+ //[TestCase(OpenAIClientServiceTarget.NonAzure)]
+ //public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget)
+ //{
+ // OpenAIClient client = GetTestClient(serviceTarget);
+ // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions);
+
+ // var requestOptions = new ChatCompletionsOptions()
+ // {
+ // Functions = { s_futureTemperatureFunction },
+ // Messages =
+ // {
+ // new ChatMessage(ChatRole.System, "You are a helpful assistant."),
+ // new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"),
+ // },
+ // MaxTokens = 512,
+ // };
+
+ // Response<StreamingChatCompletions> response
+ // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
+ // Assert.That(response, Is.Not.Null);
+
+ // using StreamingChatCompletions streamingChatCompletions = response.Value;
+
+ // ChatRole streamedRole = default;
+ // string functionName = default;
+ // StringBuilder argumentsBuilder = new();
+
+ // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
+ // {
+ // await foreach (ChatMessage message in choice.GetMessageStreaming())
+ // {
+ // if (message.Role != default)
+ // {
+ // streamedRole = message.Role;
+ // }
+ // if (message.FunctionCall?.Name != null)
+ // {
+ // Assert.That(functionName, Is.Null.Or.Empty);
+ // functionName = message.FunctionCall.Name;
+ // }
+ // argumentsBuilder.Append(message.FunctionCall?.Arguments ?? string.Empty);
+ // }
+ // }
+
+ // Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant));
+ // Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name));
+ // Assert.That(argumentsBuilder.Length, Is.GreaterThan(0));
+ //}
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
index 7376674572ee8..5b8ba0ba9f42e 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
@@ -223,55 +223,55 @@ string deploymentOrModelName
AssertExpectedContentFilterResults(firstChoice.ContentFilterResults, serviceTarget);
}
- [RecordedTest]
- [TestCase(OpenAIClientServiceTarget.Azure)]
- [TestCase(OpenAIClientServiceTarget.NonAzure)]
- public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget)
- {
- OpenAIClient client = GetTestClient(serviceTarget);
- string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
- serviceTarget,
- OpenAIClientScenario.ChatCompletions);
- var requestOptions = new ChatCompletionsOptions()
- {
- Messages =
- {
- new ChatMessage(ChatRole.System, "You are a helpful assistant."),
- new ChatMessage(ChatRole.User, "Can you help me?"),
- new ChatMessage(ChatRole.Assistant, "Of course! What do you need help with?"),
- new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"),
- },
- MaxTokens = 512,
- };
- Response<StreamingChatCompletions> streamingResponse
- = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
- Assert.That(streamingResponse, Is.Not.Null);
- using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value;
- Assert.That(streamingChatCompletions, Is.InstanceOf<StreamingChatCompletions>());
-
- int totalMessages = 0;
-
- await foreach (StreamingChatChoice streamingChoice in streamingChatCompletions.GetChoicesStreaming())
- {
- Assert.That(streamingChoice, Is.Not.Null);
- await foreach (ChatMessage streamingMessage in streamingChoice.GetMessageStreaming())
- {
- Assert.That(streamingMessage.Role, Is.EqualTo(ChatRole.Assistant));
- totalMessages++;
- }
- AssertExpectedContentFilterResults(streamingChoice.ContentFilterResults, serviceTarget);
- }
-
- Assert.That(totalMessages, Is.GreaterThan(1));
-
- // Note: these top-level values *are likely not yet populated* until *after* at least one streaming
- // choice has arrived.
- Assert.That(streamingResponse.GetRawResponse(), Is.Not.Null.Or.Empty);
- Assert.That(streamingChatCompletions.Id, Is.Not.Null.Or.Empty);
- Assert.That(streamingChatCompletions.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1))));
- Assert.That(streamingChatCompletions.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7)));
- AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1));
- }
+ //[RecordedTest]
+ //[TestCase(OpenAIClientServiceTarget.Azure)]
+ //[TestCase(OpenAIClientServiceTarget.NonAzure)]
+ //public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget)
+ //{
+ // OpenAIClient client = GetTestClient(serviceTarget);
+ // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
+ // serviceTarget,
+ // OpenAIClientScenario.ChatCompletions);
+ // var requestOptions = new ChatCompletionsOptions()
+ // {
+ // Messages =
+ // {
+ // new ChatMessage(ChatRole.System, "You are a helpful assistant."),
+ // new ChatMessage(ChatRole.User, "Can you help me?"),
+ // new ChatMessage(ChatRole.Assistant, "Of course! What do you need help with?"),
+ // new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"),
+ // },
+ // MaxTokens = 512,
+ // };
+ // Response<StreamingChatCompletions> streamingResponse
+ // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
+ // Assert.That(streamingResponse, Is.Not.Null);
+ // using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value;
+ // Assert.That(streamingChatCompletions, Is.InstanceOf<StreamingChatCompletions>());
+
+ // int totalMessages = 0;
+
+ // await foreach (StreamingChatChoice streamingChoice in streamingChatCompletions.GetChoicesStreaming())
+ // {
+ // Assert.That(streamingChoice, Is.Not.Null);
+ // await foreach (ChatMessage streamingMessage in streamingChoice.GetMessageStreaming())
+ // {
+ // Assert.That(streamingMessage.Role, Is.EqualTo(ChatRole.Assistant));
+ // totalMessages++;
+ // }
+ // AssertExpectedContentFilterResults(streamingChoice.ContentFilterResults, serviceTarget);
+ // }
+
+ // Assert.That(totalMessages, Is.GreaterThan(1));
+
+ // // Note: these top-level values *are likely not yet populated* until *after* at least one streaming
+ // // choice has arrived.
+ // Assert.That(streamingResponse.GetRawResponse(), Is.Not.Null.Or.Empty);
+ // Assert.That(streamingChatCompletions.Id, Is.Not.Null.Or.Empty);
+ // Assert.That(streamingChatCompletions.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1))));
+ // Assert.That(streamingChatCompletions.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7)));
+ // AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1));
+ //}
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs
index 21a25241f13d2..bd0ee5d1db4e1 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs
@@ -11,6 +11,39 @@ namespace Azure.AI.OpenAI.Tests.Samples
{
public partial class StreamingChat
{
+ //[Test]
+ //[Ignore("Only verifying that the sample builds")]
+ //public async Task StreamingChatWithNonAzureOpenAI()
+ //{
+ // #region Snippet:StreamChatMessages
+ // string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com";
+ // var client = new OpenAIClient(nonAzureOpenAIApiKey, new OpenAIClientOptions());
+ // var chatCompletionsOptions = new ChatCompletionsOptions()
+ // {
+ // Messages =
+ // {
+ // new ChatMessage(ChatRole.System, "You are a helpful assistant. You will talk like a pirate."),
+ // new ChatMessage(ChatRole.User, "Can you help me?"),
+ // new ChatMessage(ChatRole.Assistant, "Arrrr! Of course, me hearty! What can I do for ye?"),
+ // new ChatMessage(ChatRole.User, "What's the best way to train a parrot?"),
+ // }
+ // };
+
+ // Response<StreamingChatCompletions> response = await client.GetChatCompletionsStreamingAsync(
+ // deploymentOrModelName: "gpt-3.5-turbo",
+ // chatCompletionsOptions);
+ // using StreamingChatCompletions streamingChatCompletions = response.Value;
+
+ // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
+ // {
+ // await foreach (ChatMessage message in choice.GetMessageStreaming())
+ // {
+ // Console.Write(message.Content);
+ // }
+ // Console.WriteLine();
+ // }
+ // #endregion
+ //}
[Test]
[Ignore("Only verifying that the sample builds")]
public async Task StreamingChatWithNonAzureOpenAI()
@@ -29,18 +62,22 @@ public async Task StreamingChatWithNonAzureOpenAI()
}
};
- Response<StreamingChatCompletions> response = await client.GetChatCompletionsStreamingAsync(
+ Response<StreamingChatCompletions2> response = await client.GetChatCompletionsStreamingAsync2(
deploymentOrModelName: "gpt-3.5-turbo",
chatCompletionsOptions);
- using StreamingChatCompletions streamingChatCompletions = response.Value;
+ using StreamingChatCompletions2 streamingChatCompletions = response.Value;
- await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
+ await foreach (StreamingChatCompletionsUpdate chatUpdate
+ in streamingChatCompletions.EnumerateChatUpdates())
{
- await foreach (ChatMessage message in choice.GetMessageStreaming())
+ if (chatUpdate.Role.HasValue)
+ {
+ Console.Write($"{chatUpdate.Role.Value.ToString().ToUpperInvariant()}: ");
+ }
+ if (!string.IsNullOrEmpty(chatUpdate.ContentUpdate))
{
- Console.Write(message.Content);
+ Console.Write(chatUpdate.ContentUpdate);
}
- Console.WriteLine();
}
#endregion
}
From ac477eee70d2216a38790c2051883ce60efe11e2 Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Wed, 18 Oct 2023 16:10:47 -0700
Subject: [PATCH 04/19] completions for consistency
---
sdk/openai/Azure.AI.OpenAI/README.md | 12 +-
.../api/Azure.AI.OpenAI.netstandard2.0.cs | 41 +---
.../src/Custom/AzureOpenAIModelFactory.cs | 60 +++---
.../src/Custom/OpenAIClient.cs | 114 +----------
.../src/Custom/StreamingChatChoice.cs | 148 --------------
.../src/Custom/StreamingChatCompletions.cs | 184 ++++--------------
.../src/Custom/StreamingChatCompletions2.cs | 85 --------
.../src/Custom/StreamingChoice.cs | 150 --------------
.../src/Custom/StreamingCompletions.cs | 184 ++++--------------
.../tests/AzureChatExtensionsTests.cs | 4 +-
.../tests/ChatFunctionsTests.cs | 6 +-
.../tests/OpenAIInferenceTests.cs | 96 +++++----
.../tests/Samples/Sample04_StreamingChat.cs | 37 +---
13 files changed, 183 insertions(+), 938 deletions(-)
delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs
delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs
diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md
index e7f4acb1bfe11..cd8a7d002be92 100644
--- a/sdk/openai/Azure.AI.OpenAI/README.md
+++ b/sdk/openai/Azure.AI.OpenAI/README.md
@@ -211,13 +211,17 @@ Response response = await client.GetChatCompletionsStr
chatCompletionsOptions);
using StreamingChatCompletions streamingChatCompletions = response.Value;
-await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
+await foreach (StreamingChatCompletionsUpdate chatUpdate
+ in streamingChatCompletions.EnumerateChatUpdates())
{
- await foreach (ChatMessage message in choice.GetMessageStreaming())
+ if (chatUpdate.Role.HasValue)
{
- Console.Write(message.Content);
+ Console.Write($"{chatUpdate.Role.Value.ToString().ToUpperInvariant()}: ");
+ }
+ if (!string.IsNullOrEmpty(chatUpdate.ContentUpdate))
+ {
+ Console.Write(chatUpdate.ContentUpdate);
}
- Console.WriteLine();
}
```
diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
index 8fd561d1c7c8d..7a2dfac1a77bc 100644
--- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
+++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
@@ -206,10 +206,9 @@ public static partial class AzureOpenAIModelFactory
public static Azure.AI.OpenAI.ImageGenerations ImageGenerations(System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable data = null) { throw null; }
public static Azure.AI.OpenAI.ImageLocation ImageLocation(System.Uri url = null) { throw null; }
public static Azure.AI.OpenAI.PromptFilterResult PromptFilterResult(int promptIndex = 0, Azure.AI.OpenAI.ContentFilterResults contentFilterResults = null) { throw null; }
- public static Azure.AI.OpenAI.StreamingChatChoice StreamingChatChoice(Azure.AI.OpenAI.ChatChoice originalBaseChoice = null) { throw null; }
- public static Azure.AI.OpenAI.StreamingChatCompletions StreamingChatCompletions(Azure.AI.OpenAI.ChatCompletions baseChatCompletions = null, System.Collections.Generic.List<Azure.AI.OpenAI.StreamingChatChoice> streamingChatChoices = null) { throw null; }
- public static Azure.AI.OpenAI.StreamingChoice StreamingChoice(Azure.AI.OpenAI.Choice originalBaseChoice = null) { throw null; }
- public static Azure.AI.OpenAI.StreamingCompletions StreamingCompletions(Azure.AI.OpenAI.Completions baseCompletions = null, System.Collections.Generic.List<Azure.AI.OpenAI.StreamingChoice> streamingChoices = null) { throw null; }
+ public static Azure.AI.OpenAI.StreamingChatCompletions StreamingChatCompletions(System.Collections.Generic.IEnumerable<Azure.AI.OpenAI.StreamingChatCompletionsUpdate> updates) { throw null; }
+ public static Azure.AI.OpenAI.StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(string id, System.DateTimeOffset created, int? choiceIndex = default(int?), Azure.AI.OpenAI.ChatRole? role = default(Azure.AI.OpenAI.ChatRole?), string authorName = null, string contentUpdate = null, Azure.AI.OpenAI.CompletionsFinishReason? finishReason = default(Azure.AI.OpenAI.CompletionsFinishReason?), string functionName = null, string functionArgumentsUpdate = null, Azure.AI.OpenAI.AzureChatExtensionsMessageContext azureExtensionsContext = null) { throw null; }
+ public static Azure.AI.OpenAI.StreamingCompletions StreamingCompletions(System.Collections.Generic.IEnumerable<Azure.AI.OpenAI.Completions> completions) { throw null; }
}
public partial class ChatChoice
{
@@ -481,9 +480,7 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent
public virtual Azure.Response<Azure.AI.OpenAI.ChatCompletions> GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.ChatCompletions>> GetChatCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response<Azure.AI.OpenAI.StreamingChatCompletions> GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- public virtual Azure.Response<Azure.AI.OpenAI.StreamingChatCompletions2> GetChatCompletionsStreaming2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.StreamingChatCompletions>> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.StreamingChatCompletions2>> GetChatCompletionsStreamingAsync2(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response<Azure.AI.OpenAI.Completions> GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response<Azure.AI.OpenAI.Completions> GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.Completions>> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
@@ -514,27 +511,9 @@ internal PromptFilterResult() { }
public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } }
public int PromptIndex { get { throw null; } }
}
- public partial class StreamingChatChoice
- {
- internal StreamingChatChoice() { }
- public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } }
- public Azure.AI.OpenAI.CompletionsFinishReason? FinishReason { get { throw null; } }
- public int? Index { get { throw null; } }
- public System.Collections.Generic.IAsyncEnumerable<Azure.AI.OpenAI.ChatMessage> GetMessageStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- }
public partial class StreamingChatCompletions : System.IDisposable
{
internal StreamingChatCompletions() { }
- public System.DateTimeOffset Created { get { throw null; } }
- public string Id { get { throw null; } }
- public System.Collections.Generic.IReadOnlyList PromptFilterResults { get { throw null; } }
- public void Dispose() { }
- protected virtual void Dispose(bool disposing) { }
- public System.Collections.Generic.IAsyncEnumerable<Azure.AI.OpenAI.StreamingChatChoice> GetChoicesStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- }
- public partial class StreamingChatCompletions2 : System.IDisposable
- {
- internal StreamingChatCompletions2() { }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
public System.Collections.Generic.IAsyncEnumerable<Azure.AI.OpenAI.StreamingChatCompletionsUpdate> EnumerateChatUpdates([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
@@ -553,24 +532,12 @@ internal StreamingChatCompletionsUpdate() { }
public string Id { get { throw null; } }
public Azure.AI.OpenAI.ChatRole? Role { get { throw null; } }
}
- public partial class StreamingChoice
- {
- internal StreamingChoice() { }
- public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } }
- public Azure.AI.OpenAI.CompletionsFinishReason? FinishReason { get { throw null; } }
- public int? Index { get { throw null; } }
- public Azure.AI.OpenAI.CompletionsLogProbabilityModel LogProbabilityModel { get { throw null; } }
- public System.Collections.Generic.IAsyncEnumerable<string> GetTextStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- }
public partial class StreamingCompletions : System.IDisposable
{
internal StreamingCompletions() { }
- public System.DateTimeOffset Created { get { throw null; } }
- public string Id { get { throw null; } }
- public System.Collections.Generic.IReadOnlyList PromptFilterResults { get { throw null; } }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
- public System.Collections.Generic.IAsyncEnumerable<Azure.AI.OpenAI.StreamingChoice> GetChoicesStreaming([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public System.Collections.Generic.IAsyncEnumerable<Azure.AI.OpenAI.Completions> EnumerateCompletions([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
}
}
namespace Microsoft.Extensions.Azure
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs
index a879727738372..cc90ea19685df 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs
@@ -30,50 +30,52 @@ public static ChatChoice ChatChoice(
return new ChatChoice(message, index, finishReason, deltaMessage, contentFilterResults);
}
- ///
- /// Initializes a new instance of StreamingChoice for tests and mocking.
- ///
- /// An underlying Choice for this streaming representation
- /// A new instance of StreamingChoice
- public static StreamingChoice StreamingChoice(Choice originalBaseChoice = null)
- {
- return new StreamingChoice(originalBaseChoice);
- }
-
///
/// Initializes a new instance of StreamingCompletions for tests and mocking.
///
- /// The non-streaming completions to base this streaming representation on
- /// The streaming choices associated with this streaming completions
- /// A new instance of StreamingCompletions
- public static StreamingCompletions StreamingCompletions(
- Completions baseCompletions = null,
- List streamingChoices = null)
+ /// A static set of Completions instances to use for tests and mocking.
+ /// A new instance of StreamingCompletions.
+ public static StreamingCompletions StreamingCompletions(IEnumerable<Completions> completions)
{
- return new StreamingCompletions(baseCompletions, streamingChoices);
+ return new StreamingCompletions(completions);
}
- ///
- /// Initializes a new instance of StreamingChatChoice for tests and mocking.
- ///
- /// An underlying ChatChoice for this streaming representation
- /// A new instance of StreamingChatChoice
- public static StreamingChatChoice StreamingChatChoice(ChatChoice originalBaseChoice = null)
+ public static StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(
+ string id,
+ DateTimeOffset created,
+ int? choiceIndex = null,
+ ChatRole? role = null,
+ string authorName = null,
+ string contentUpdate = null,
+ CompletionsFinishReason? finishReason = null,
+ string functionName = null,
+ string functionArgumentsUpdate = null,
+ AzureChatExtensionsMessageContext azureExtensionsContext = null)
{
- return new StreamingChatChoice(originalBaseChoice);
+ return new StreamingChatCompletionsUpdate(
+ id,
+ created,
+ choiceIndex,
+ role,
+ authorName,
+ contentUpdate,
+ finishReason,
+ functionName,
+ functionArgumentsUpdate,
+ azureExtensionsContext);
}
///
/// Initializes a new instance of StreamingChatCompletions for tests and mocking.
///
- /// The non-streaming completions to base this streaming representation on
- /// The streaming choices associated with this streaming chat completions
+ ///
+ /// A static collection of StreamingChatCompletionUpdates to use for tests and mocking.
+ ///
/// A new instance of StreamingChatCompletions
public static StreamingChatCompletions StreamingChatCompletions(
- ChatCompletions baseChatCompletions = null,
- List streamingChatChoices = null)
+ IEnumerable<StreamingChatCompletionsUpdate> updates)
{
- return new StreamingChatCompletions(baseChatCompletions, streamingChatChoices);
+ return new StreamingChatCompletions(updates);
}
/// Initializes a new instance of AudioTranscription.
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
index 6c44d50e6cc19..e46691d42ae88 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
@@ -400,106 +400,6 @@ public virtual async Task> GetChatCompletionsAsync(
}
}
- /////
- ///// Begin a chat completions request and get an object that can stream response data as it becomes
- ///// available.
- /////
- /////
- /////
- /////
- /////
- ///// the chat completions options for this chat completions request.
- /////
- /////
- ///// a cancellation token that can be used to cancel the initial request or ongoing streaming operation.
- /////
- /////
- ///// or is null.
- /////
- ///// Service returned a non-success status code.
- ///// The response returned from the service.
- //public virtual Response GetChatCompletionsStreaming(
- // string deploymentOrModelName,
- // ChatCompletionsOptions chatCompletionsOptions,
- // CancellationToken cancellationToken = default)
- //{
- // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
- // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
-
- // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
- // scope.Start();
-
- // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
- // chatCompletionsOptions.InternalShouldStreamResponse = true;
-
- // string operationPath = GetOperationPath(chatCompletionsOptions);
-
- // RequestContent content = chatCompletionsOptions.ToRequestContent();
- // RequestContext context = FromCancellationToken(cancellationToken);
-
- // try
- // {
- // // Response value object takes IDisposable ownership of message
- // HttpMessage message = CreatePostRequestMessage(
- // deploymentOrModelName,
- // operationPath,
- // content,
- // context);
- // message.BufferResponse = false;
- // Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
- // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
- // }
- // catch (Exception e)
- // {
- // scope.Failed(e);
- // throw;
- // }
- //}
-
- /////
- //public virtual async Task> GetChatCompletionsStreamingAsync(
- // string deploymentOrModelName,
- // ChatCompletionsOptions chatCompletionsOptions,
- // CancellationToken cancellationToken = default)
- //{
- // Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
- // Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
-
- // using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
- // scope.Start();
-
- // chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
- // chatCompletionsOptions.InternalShouldStreamResponse = true;
-
- // string operationPath = GetOperationPath(chatCompletionsOptions);
-
- // RequestContent content = chatCompletionsOptions.ToRequestContent();
- // RequestContext context = FromCancellationToken(cancellationToken);
-
- // try
- // {
- // // Response value object takes IDisposable ownership of message
- // HttpMessage message = CreatePostRequestMessage(
- // deploymentOrModelName,
- // operationPath,
- // content,
- // context);
- // message.BufferResponse = false;
- // Response baseResponse = await _pipeline.ProcessMessageAsync(
- // message,
- // context,
- // cancellationToken).ConfigureAwait(false);
- // return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
- // }
- // catch (Exception e)
- // {
- // scope.Failed(e);
- // throw;
- // }
- //}
-
///
/// Begin a chat completions request and get an object that can stream response data as it becomes
/// available.
@@ -520,7 +420,7 @@ public virtual async Task> GetChatCompletionsAsync(
///
/// Service returned a non-success status code.
/// The response returned from the service.
- public virtual Response<StreamingChatCompletions2> GetChatCompletionsStreaming2(
+ public virtual Response<StreamingChatCompletions> GetChatCompletionsStreaming(
string deploymentOrModelName,
ChatCompletionsOptions chatCompletionsOptions,
CancellationToken cancellationToken = default)
@@ -528,7 +428,7 @@ public virtual Response GetChatCompletionsStreaming2(
Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
- using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2");
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
scope.Start();
chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
@@ -549,7 +449,7 @@ public virtual Response GetChatCompletionsStreaming2(
context);
message.BufferResponse = false;
Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
- return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse);
+ return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
}
catch (Exception e)
{
@@ -558,8 +458,8 @@ public virtual Response GetChatCompletionsStreaming2(
}
}
- ///
- public virtual async Task<Response<StreamingChatCompletions2>> GetChatCompletionsStreamingAsync2(
+ ///
+ public virtual async Task<Response<StreamingChatCompletions>> GetChatCompletionsStreamingAsync(
string deploymentOrModelName,
ChatCompletionsOptions chatCompletionsOptions,
CancellationToken cancellationToken = default)
@@ -567,7 +467,7 @@ public virtual async Task> GetChatCompletion
Argument.AssertNotNull(deploymentOrModelName, nameof(deploymentOrModelName));
Argument.AssertNotNull(chatCompletionsOptions, nameof(chatCompletionsOptions));
- using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming2");
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetChatCompletionsStreaming");
scope.Start();
chatCompletionsOptions.InternalNonAzureModelName = _isConfiguredForAzureOpenAI ? null : deploymentOrModelName;
@@ -591,7 +491,7 @@ public virtual async Task> GetChatCompletion
message,
context,
cancellationToken).ConfigureAwait(false);
- return Response.FromValue(new StreamingChatCompletions2(baseResponse), baseResponse);
+ return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
}
catch (Exception e)
{
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs
deleted file mode 100644
index 0a631b29116fc..0000000000000
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatChoice.cs
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Threading;
-
-namespace Azure.AI.OpenAI
-{
- public class StreamingChatChoice
- {
- private readonly IList<ChatChoice> _baseChoices;
- private readonly object _baseChoicesLock = new object();
- private readonly AsyncAutoResetEvent _updateAvailableEvent;
-
- ///
- /// Gets the response index associated with this StreamingChoice as represented relative to other Choices
- /// in the same Completions response.
- ///
- ///
- /// Indices may be used to correlate individual Choices within a Completions result to their configured
- /// prompts as provided in the request. As an example, if two Choices are requested for each of four prompts,
- /// the Choices with indices 0 and 1 will correspond to the first prompt, 2 and 3 to the second, and so on.
- ///
- public int? Index => GetLocked(() => _baseChoices.Last().Index);
-
- ///
- /// Gets a value representing why response generation ended when producing this StreamingChoice.
- ///
- ///
- /// Normal termination typically provides "stop" and encountering token limits in a request typically
- /// provides "length." If no value is present, this StreamingChoice is still in progress.
- ///
- public CompletionsFinishReason? FinishReason => GetLocked(() => _baseChoices.Last().FinishReason);
-
- ///
- /// Information about the content filtering category (hate, sexual, violence, self_harm), if it
- /// has been detected, as well as the severity level (very_low, low, medium, high-scale that
- /// determines the intensity and risk level of harmful content) and if it has been filtered or not.
- ///
- public ContentFilterResults ContentFilterResults
- => GetLocked(() =>
- {
- return _baseChoices
- .LastOrDefault(baseChoice => baseChoice.ContentFilterResults != null && baseChoice.ContentFilterResults.Hate != null)
- ?.ContentFilterResults;
- });
-
- internal ChatMessage StreamingDeltaMessage { get; set; }
-
- private bool _isFinishedStreaming { get; set; } = false;
-
- private Exception _pumpException { get; set; }
-
- internal StreamingChatChoice(ChatChoice originalBaseChoice)
- {
- _baseChoices = new List<ChatChoice>() { originalBaseChoice };
- _updateAvailableEvent = new AsyncAutoResetEvent();
- }
-
- internal void UpdateFromEventStreamChatChoice(ChatChoice streamingChatChoice)
- {
- lock (_baseChoicesLock)
- {
- _baseChoices.Add(streamingChatChoice);
- }
- if (streamingChatChoice.FinishReason != null)
- {
- EnsureFinishStreaming();
- }
- _updateAvailableEvent.Set();
- }
-
- ///
- /// Gets an asynchronous enumeration that will retrieve the Completion text associated with this Choice as it
- /// becomes available. Each string will provide one or more tokens, including whitespace, and a full
- /// concatenation of all enumerated strings is functionally equivalent to the single Text property on a
- /// non-streaming Completions Choice.
- ///
- ///
- ///
- public async IAsyncEnumerable<ChatMessage> GetMessageStreaming(
- [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- bool isFinalIndex = false;
- for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++)
- {
- bool doneWaiting = false;
- while (!doneWaiting)
- {
- lock (_baseChoicesLock)
- {
- ChatChoice mostRecentChoice = _baseChoices.Last();
-
- doneWaiting = _isFinishedStreaming || i < _baseChoices.Count;
- isFinalIndex = _isFinishedStreaming && i >= _baseChoices.Count - 1;
- }
-
- if (!doneWaiting)
- {
- await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false);
- }
- }
-
- if (_pumpException != null)
- {
- throw _pumpException;
- }
-
- ChatMessage message = null;
-
- lock (_baseChoicesLock)
- {
- if (i < _baseChoices.Count)
- {
- message = _baseChoices[i].InternalStreamingDeltaMessage;
- message.Role = _baseChoices.First().InternalStreamingDeltaMessage.Role;
- }
- }
-
- if (message != null)
- {
- yield return message;
- }
- }
- }
-
- internal void EnsureFinishStreaming(Exception pumpException = null)
- {
- if (!_isFinishedStreaming)
- {
- _isFinishedStreaming = true;
- _pumpException = pumpException;
- _updateAvailableEvent.Set();
- }
- }
-
- private T GetLocked<T>(Func<T> func)
- {
- lock (_baseChoicesLock)
- {
- return func.Invoke();
- }
- }
- }
-}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
index 71dd993e240d8..fe868d29f6e60 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
@@ -8,7 +8,6 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
-using System.Threading.Tasks;
using Azure.Core.Sse;
namespace Azure.AI.OpenAI
@@ -17,169 +16,66 @@ public class StreamingChatCompletions : IDisposable
{
private readonly Response _baseResponse;
private readonly SseReader _baseResponseReader;
- private readonly IList<ChatCompletions> _baseChatCompletions;
- private readonly object _baseCompletionsLock = new();
- private readonly IList<StreamingChatChoice> _streamingChatChoices;
- private readonly object _streamingChoicesLock = new();
- private readonly AsyncAutoResetEvent _updateAvailableEvent;
- private bool _streamingTaskComplete;
+ private List<StreamingChatCompletionsUpdate> _cachedUpdates;
private bool _disposedValue;
- private Exception _pumpException;
-
- ///
- /// Gets the earliest Completion creation timestamp associated with this streamed response.
- ///
- public DateTimeOffset Created => GetLocked(() => _baseChatCompletions.Last().Created);
-
- ///
- /// Gets the unique identifier associated with this streaming Completions response.
- ///
- public string Id => GetLocked(() => _baseChatCompletions.Last().Id);
-
- ///
- /// Content filtering results for zero or more prompts in the request. In a streaming request,
- /// results for different prompts may arrive at different times or in different orders.
- ///
- public IReadOnlyList PromptFilterResults
- => GetLocked(() =>
- {
- return _baseChatCompletions.Where(singleBaseChatCompletion => singleBaseChatCompletion.PromptFilterResults != null)
- .SelectMany(singleBaseChatCompletion => singleBaseChatCompletion.PromptFilterResults)
- .OrderBy(singleBaseCompletion => singleBaseCompletion.PromptIndex)
- .ToList();
- });
internal StreamingChatCompletions(Response response)
{
_baseResponse = response;
_baseResponseReader = new SseReader(response.ContentStream);
- _updateAvailableEvent = new AsyncAutoResetEvent();
- _baseChatCompletions = new List<ChatCompletions>();
- _streamingChatChoices = new List<StreamingChatChoice>();
- _streamingTaskComplete = false;
- _ = Task.Run(async () =>
- {
- try
- {
- while (true)
- {
- SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
- if (sseEvent == null)
- {
- _baseResponse.ContentStream?.Dispose();
- break;
- }
-
- ReadOnlyMemory<char> name = sseEvent.Value.FieldName;
- if (!name.Span.SequenceEqual("data".AsSpan()))
- throw new InvalidDataException();
-
- ReadOnlyMemory<char> value = sseEvent.Value.FieldValue;
- if (value.Span.SequenceEqual("[DONE]".AsSpan()))
- {
- _baseResponse.ContentStream?.Dispose();
- break;
- }
-
- JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
- ChatCompletions chatCompletionsFromSse = ChatCompletions.DeserializeChatCompletions(sseMessageJson.RootElement);
-
- lock (_baseCompletionsLock)
- {
- _baseChatCompletions.Add(chatCompletionsFromSse);
- }
+ _cachedUpdates = new();
+ }
- foreach (ChatChoice chatChoiceFromSse in chatCompletionsFromSse.Choices)
- {
- lock (_streamingChoicesLock)
- {
- StreamingChatChoice existingStreamingChoice = _streamingChatChoices
- .FirstOrDefault(chatChoice => chatChoice.Index == chatChoiceFromSse.Index);
- if (existingStreamingChoice == null)
- {
- StreamingChatChoice newStreamingChatChoice = new(chatChoiceFromSse);
- _streamingChatChoices.Add(newStreamingChatChoice);
- _updateAvailableEvent.Set();
- }
- else
- {
- existingStreamingChoice.UpdateFromEventStreamChatChoice(chatChoiceFromSse);
- }
- }
- }
- }
- }
- catch (Exception pumpException)
- {
- _pumpException = pumpException;
- }
- finally
- {
- lock (_streamingChoicesLock)
- {
- // If anything went wrong and a StreamingChatChoice didn't naturally determine it was complete
- // based on a non-null finish reason, ensure that nothing's left incomplete (and potentially
- // hanging!) now.
- foreach (StreamingChatChoice streamingChatChoice in _streamingChatChoices)
- {
- streamingChatChoice.EnsureFinishStreaming(_pumpException);
- }
- }
- _streamingTaskComplete = true;
- _updateAvailableEvent.Set();
- }
- });
+ internal StreamingChatCompletions(IEnumerable<StreamingChatCompletionsUpdate> updates)
+ {
+ _cachedUpdates = new List<StreamingChatCompletionsUpdate>(updates);
}
- public async IAsyncEnumerable<StreamingChatChoice> GetChoicesStreaming(
+ public async IAsyncEnumerable<StreamingChatCompletionsUpdate> EnumerateChatUpdates(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- bool isFinalIndex = false;
- for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++)
+ if (_cachedUpdates.Any())
{
- bool doneWaiting = false;
- while (!doneWaiting)
+ foreach (StreamingChatCompletionsUpdate update in _cachedUpdates)
{
- lock (_streamingChoicesLock)
- {
- doneWaiting = _streamingTaskComplete || i < _streamingChatChoices.Count;
- isFinalIndex = _streamingTaskComplete && i >= _streamingChatChoices.Count - 1;
- }
-
- if (!doneWaiting)
- {
- await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false);
- }
+ yield return update;
}
+ yield break;
+ }
- if (_pumpException != null)
+ // Open clarification points:
+ // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested?
+ // - What should happen if the stream is held too long and it times out?
+ // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently?
+ while (!cancellationToken.IsCancellationRequested)
+ {
+ SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
+ if (sseEvent == null)
{
- throw _pumpException;
+ _baseResponse.ContentStream?.Dispose();
+ break;
}
- StreamingChatChoice newChatChoice = null;
- lock (_streamingChoicesLock)
+ ReadOnlyMemory<char> name = sseEvent.Value.FieldName;
+ if (!name.Span.SequenceEqual("data".AsSpan()))
+ throw new InvalidDataException();
+
+ ReadOnlyMemory<char> value = sseEvent.Value.FieldValue;
+ if (value.Span.SequenceEqual("[DONE]".AsSpan()))
{
- if (i < _streamingChatChoices.Count)
- {
- newChatChoice = _streamingChatChoices[i];
- }
+ _baseResponse.ContentStream?.Dispose();
+ break;
}
- if (newChatChoice != null)
+ JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
+ foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate
+ in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement))
{
- yield return newChatChoice;
+ _cachedUpdates.Add(streamingChatCompletionsUpdate);
+ yield return streamingChatCompletionsUpdate;
}
}
- }
-
- internal StreamingChatCompletions(
- ChatCompletions baseChatCompletions = null,
- List<StreamingChatChoice> streamingChatChoices = null)
- {
- _baseChatCompletions.Add(baseChatCompletions);
- _streamingChatChoices = streamingChatChoices;
- _streamingTaskComplete = true;
+ _baseResponse?.ContentStream?.Dispose();
}
public void Dispose()
@@ -200,13 +96,5 @@ protected virtual void Dispose(bool disposing)
_disposedValue = true;
}
}
-
- private T GetLocked<T>(Func<T> func)
- {
- lock (_baseCompletionsLock)
- {
- return func.Invoke();
- }
- }
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
deleted file mode 100644
index 967dea227f2a7..0000000000000
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions2.cs
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text.Json;
-using System.Threading;
-using Azure.Core.Sse;
-
-namespace Azure.AI.OpenAI
-{
- public class StreamingChatCompletions2 : IDisposable
- {
- private readonly Response _baseResponse;
- private readonly SseReader _baseResponseReader;
- private bool _disposedValue;
-
- internal StreamingChatCompletions2(Response response)
- {
- _baseResponse = response;
- _baseResponseReader = new SseReader(response.ContentStream);
- }
-
- public async IAsyncEnumerable<StreamingChatCompletionsUpdate> EnumerateChatUpdates(
- [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- // Open clarification points:
- // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested?
- // - What should happen if the stream is held too long and it times out?
- // - Should enumeration be repeatable, e.g. via locally caching a collection of the updates?
- // - What if the initial enumeration was cancelled?
- // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently?
- while (!cancellationToken.IsCancellationRequested)
- {
- SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
- if (sseEvent == null)
- {
- _baseResponse.ContentStream?.Dispose();
- break;
- }
-
- ReadOnlyMemory<char> name = sseEvent.Value.FieldName;
- if (!name.Span.SequenceEqual("data".AsSpan()))
- throw new InvalidDataException();
-
- ReadOnlyMemory<char> value = sseEvent.Value.FieldValue;
- if (value.Span.SequenceEqual("[DONE]".AsSpan()))
- {
- _baseResponse.ContentStream?.Dispose();
- break;
- }
-
- JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
- foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate
- in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement))
- {
- yield return streamingChatCompletionsUpdate;
- }
- }
- _baseResponse?.ContentStream?.Dispose();
- }
-
- public void Dispose()
- {
- Dispose(disposing: true);
- GC.SuppressFinalize(this);
- }
-
- protected virtual void Dispose(bool disposing)
- {
- if (!_disposedValue)
- {
- if (disposing)
- {
- _baseResponseReader.Dispose();
- }
-
- _disposedValue = true;
- }
- }
- }
-}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs
deleted file mode 100644
index 51fbde8ed9a60..0000000000000
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChoice.cs
+++ /dev/null
@@ -1,150 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Threading;
-
-namespace Azure.AI.OpenAI
-{
- public class StreamingChoice
- {
- private readonly IList<Choice> _baseChoices;
- private readonly object _baseChoicesLock = new object();
- private readonly AsyncAutoResetEvent _updateAvailableEvent;
-
- ///
- /// Gets the response index associated with this StreamingChoice as represented relative to other Choices
- /// in the same Completions response.
- ///
- ///
- /// Indices may be used to correlate individual Choices within a Completions result to their configured
- /// prompts as provided in the request. As an example, if two Choices are requested for each of four prompts,
- /// the Choices with indices 0 and 1 will correspond to the first prompt, 2 and 3 to the second, and so on.
- ///
- public int? Index => GetLocked(() => _baseChoices.Last().Index);
-
- ///
- /// Gets a value representing why response generation ended when producing this StreamingChoice.
- ///
- ///
- /// Normal termination typically provides "stop" and encountering token limits in a request typically
- /// provides "length." If no value is present, this StreamingChoice is still in progress.
- ///
- public CompletionsFinishReason? FinishReason => GetLocked(() => _baseChoices.Last().FinishReason);
-
- ///
- /// Information about the content filtering category (hate, sexual, violence, self_harm), if it
- /// has been detected, as well as the severity level (very_low, low, medium, high-scale that
- /// determines the intensity and risk level of harmful content) and if it has been filtered or not.
- ///
- public ContentFilterResults ContentFilterResults
- => GetLocked(() =>
- {
- return _baseChoices
- .LastOrDefault(baseChoice => baseChoice.ContentFilterResults != null && baseChoice.ContentFilterResults.Hate != null)
- ?.ContentFilterResults;
- });
-
- private bool _isFinishedStreaming { get; set; } = false;
-
- private Exception _pumpException { get; set; }
-
- ///
- /// Gets the log probabilities associated with tokens in this Choice.
- ///
- public CompletionsLogProbabilityModel LogProbabilityModel
- => GetLocked(() => _baseChoices.Last().LogProbabilityModel);
-
- internal StreamingChoice(Choice originalBaseChoice)
- {
- _baseChoices = new List<Choice>() { originalBaseChoice };
- _updateAvailableEvent = new AsyncAutoResetEvent();
- }
-
- internal void UpdateFromEventStreamChoice(Choice streamingChoice)
- {
- lock (_baseChoicesLock)
- {
- _baseChoices.Add(streamingChoice);
- }
- if (streamingChoice.FinishReason != null)
- {
- EnsureFinishStreaming();
- }
- _updateAvailableEvent.Set();
- }
-
- ///
- /// Gets an asynchronous enumeration that will retrieve the Completion text associated with this Choice as it
- /// becomes available. Each string will provide one or more tokens, including whitespace, and a full
- /// concatenation of all enumerated strings is functionally equivalent to the single Text property on a
- /// non-streaming Completions Choice.
- ///
- ///
- ///
- public async IAsyncEnumerable<string> GetTextStreaming(
- [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- bool isFinalIndex = false;
- for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++)
- {
- bool doneWaiting = false;
- while (!doneWaiting)
- {
- lock (_baseChoicesLock)
- {
- Choice mostRecentChoice = _baseChoices.Last();
-
- doneWaiting = _isFinishedStreaming || i < _baseChoices.Count;
- isFinalIndex = _isFinishedStreaming && i >= _baseChoices.Count - 1;
- }
-
- if (!doneWaiting)
- {
- await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false);
- }
- }
-
- if (_pumpException != null)
- {
- throw _pumpException;
- }
-
- string newText = string.Empty;
- lock (_baseChoicesLock)
- {
- if (i < _baseChoices.Count)
- {
- newText = _baseChoices[i].Text;
- }
- }
-
- if (!string.IsNullOrEmpty(newText))
- {
- yield return newText;
- }
- }
- }
-
- internal void EnsureFinishStreaming(Exception pumpException = null)
- {
- if (!_isFinishedStreaming)
- {
- _isFinishedStreaming = true;
- _pumpException = pumpException;
- _updateAvailableEvent.Set();
- }
- }
-
- private T GetLocked<T>(Func<T> func)
- {
- lock (_baseChoicesLock)
- {
- return func.Invoke();
- }
- }
- }
-}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
index 70796d90f156c..fa032a2f28f50 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
@@ -17,169 +17,63 @@ public class StreamingCompletions : IDisposable
{
private readonly Response _baseResponse;
private readonly SseReader _baseResponseReader;
- private readonly IList<Completions> _baseCompletions;
- private readonly object _baseCompletionsLock = new();
- private readonly IList<StreamingChoice> _streamingChoices;
- private readonly object _streamingChoicesLock = new();
- private readonly AsyncAutoResetEvent _updateAvailableEvent;
- private bool _streamingTaskComplete;
+ private List<Completions> _cachedCompletions;
private bool _disposedValue;
- private Exception _pumpException;
-
- ///
- /// Gets the earliest Completion creation timestamp associated with this streamed response.
- ///
- public DateTimeOffset Created => GetLocked(() => _baseCompletions.Last().Created);
-
- ///
- /// Gets the unique identifier associated with this streaming Completions response.
- ///
- public string Id => GetLocked(() => _baseCompletions.Last().Id);
-
- ///
- /// Content filtering results for zero or more prompts in the request. In a streaming request,
- /// results for different prompts may arrive at different times or in different orders.
- ///
- public IReadOnlyList PromptFilterResults
- => GetLocked(() =>
- {
- return _baseCompletions.Where(singleBaseCompletion => singleBaseCompletion.PromptFilterResults != null)
- .SelectMany(singleBaseCompletion => singleBaseCompletion.PromptFilterResults)
- .OrderBy(singleBaseCompletion => singleBaseCompletion.PromptIndex)
- .ToList();
- });
internal StreamingCompletions(Response response)
{
_baseResponse = response;
_baseResponseReader = new SseReader(response.ContentStream);
- _updateAvailableEvent = new AsyncAutoResetEvent();
- _baseCompletions = new List<Completions>();
- _streamingChoices = new List<StreamingChoice>();
- _streamingTaskComplete = false;
- _ = Task.Run(async () =>
- {
- try
- {
- while (true)
- {
- SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
- if (sseEvent == null)
- {
- _baseResponse.ContentStream?.Dispose();
- break;
- }
-
- ReadOnlyMemory<char> name = sseEvent.Value.FieldName;
- if (!name.Span.SequenceEqual("data".AsSpan()))
- throw new InvalidDataException();
-
- ReadOnlyMemory<char> value = sseEvent.Value.FieldValue;
- if (value.Span.SequenceEqual("[DONE]".AsSpan()))
- {
- _baseResponse.ContentStream?.Dispose();
- break;
- }
-
- JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
- Completions completionsFromSse = Completions.DeserializeCompletions(sseMessageJson.RootElement);
-
- lock (_baseCompletionsLock)
- {
- _baseCompletions.Add(completionsFromSse);
- }
+ _cachedCompletions = new();
+ }
- foreach (Choice choiceFromSse in completionsFromSse.Choices)
- {
- lock (_streamingChoicesLock)
- {
- StreamingChoice existingStreamingChoice = _streamingChoices
- .FirstOrDefault(choice => choice.Index == choiceFromSse.Index);
- if (existingStreamingChoice == null)
- {
- StreamingChoice newStreamingChoice = new(choiceFromSse);
- _streamingChoices.Add(newStreamingChoice);
- _updateAvailableEvent.Set();
- }
- else
- {
- existingStreamingChoice.UpdateFromEventStreamChoice(choiceFromSse);
- }
- }
- }
- }
- }
- catch (Exception pumpException)
- {
- _pumpException = pumpException;
- }
- finally
- {
- lock (_streamingChoicesLock)
- {
- // If anything went wrong and a StreamingChoice didn't naturally determine it was complete
- // based on a non-null finish reason, ensure that nothing's left incomplete (and potentially
- // hanging!) now.
- foreach (StreamingChoice streamingChoice in _streamingChoices)
- {
- streamingChoice.EnsureFinishStreaming(_pumpException);
- }
- }
- _streamingTaskComplete = true;
- _updateAvailableEvent.Set();
- }
- });
+ internal StreamingCompletions(IEnumerable<Completions> completions)
+ {
+ _cachedCompletions = new List<Completions>(completions);
}
- public async IAsyncEnumerable<StreamingChoice> GetChoicesStreaming(
+ public async IAsyncEnumerable<Completions> EnumerateCompletions(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- bool isFinalIndex = false;
- for (int i = 0; !isFinalIndex && !cancellationToken.IsCancellationRequested; i++)
+ if (_cachedCompletions.Any())
{
- bool doneWaiting = false;
- while (!doneWaiting)
+ foreach (Completions completions in _cachedCompletions)
{
- lock (_streamingChoicesLock)
- {
- doneWaiting = _streamingTaskComplete || i < _streamingChoices.Count;
- isFinalIndex = _streamingTaskComplete && i >= _streamingChoices.Count - 1;
- }
-
- if (!doneWaiting)
- {
- await _updateAvailableEvent.WaitAsync(cancellationToken).ConfigureAwait(false);
- }
+ yield return completions;
}
+ yield break;
+ }
- if (_pumpException != null)
+ // Open clarification points:
+ // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested?
+ // - What should happen if the stream is held too long and it times out?
+ // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently?
+ while (!cancellationToken.IsCancellationRequested)
+ {
+ SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
+ if (sseEvent == null)
{
- throw _pumpException;
+ _baseResponse.ContentStream?.Dispose();
+ break;
}
- StreamingChoice newChoice = null;
- lock (_streamingChoicesLock)
- {
- if (i < _streamingChoices.Count)
- {
- newChoice = _streamingChoices[i];
- }
- }
+ ReadOnlyMemory<char> name = sseEvent.Value.FieldName;
+ if (!name.Span.SequenceEqual("data".AsSpan()))
+ throw new InvalidDataException();
- if (newChoice != null)
+ ReadOnlyMemory<char> value = sseEvent.Value.FieldValue;
+ if (value.Span.SequenceEqual("[DONE]".AsSpan()))
{
- yield return newChoice;
+ _baseResponse.ContentStream?.Dispose();
+ break;
}
- }
- }
- internal StreamingCompletions(
- Completions baseCompletions = null,
- IList<StreamingChoice> streamingChoices = null)
- {
- _baseCompletions.Add(baseCompletions);
- _streamingChoices = streamingChoices;
- _streamingTaskComplete = true;
+ JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
+ Completions sseCompletions = Completions.DeserializeCompletions(sseMessageJson.RootElement);
+ _cachedCompletions.Add(sseCompletions);
+ yield return sseCompletions;
+ }
+ _baseResponse?.ContentStream?.Dispose();
}
public void Dispose()
@@ -200,13 +94,5 @@ protected virtual void Dispose(bool disposing)
_disposedValue = true;
}
}
-
- private T GetLocked(Func func)
- {
- lock (_baseCompletionsLock)
- {
- return func.Invoke();
- }
- }
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
index 8f913da5a4dfd..c72a2abcfdead 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
@@ -186,13 +186,13 @@ public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget servic
},
};
- Response response = await client.GetChatCompletionsStreamingAsync2(
+ Response response = await client.GetChatCompletionsStreamingAsync(
deploymentOrModelName,
requestOptions);
Assert.That(response, Is.Not.Null);
Assert.That(response.Value, Is.Not.Null);
- using StreamingChatCompletions2 streamingChatCompletions = response.Value;
+ using StreamingChatCompletions streamingChatCompletions = response.Value;
ChatRole? streamedRole = null;
IEnumerable azureContextMessages = null;
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
index c7d3289706775..1f3a2d72ffb3b 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
@@ -154,11 +154,11 @@ public async Task StreamingFunctionCallWorks2(OpenAIClientServiceTarget serviceT
MaxTokens = 512,
};
- Response response
- = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions);
+ Response response
+ = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
Assert.That(response, Is.Not.Null);
- using StreamingChatCompletions2 streamingChatCompletions = response.Value;
+ using StreamingChatCompletions streamingChatCompletions = response.Value;
ChatRole? streamedRole = default;
string functionName = null;
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
index 5b8ba0ba9f42e..e5eabcaa1a50c 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
@@ -293,11 +293,11 @@ public async Task StreamingChatCompletions2(OpenAIClientServiceTarget serviceTar
},
MaxTokens = 512,
};
- Response streamingResponse
- = await client.GetChatCompletionsStreamingAsync2(deploymentOrModelName, requestOptions);
+ Response streamingResponse
+ = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
Assert.That(streamingResponse, Is.Not.Null);
- using StreamingChatCompletions2 streamingChatCompletions = streamingResponse.Value;
- Assert.That(streamingChatCompletions, Is.InstanceOf());
+ using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value;
+ Assert.That(streamingChatCompletions, Is.InstanceOf());
StringBuilder contentBuilder = new StringBuilder();
bool gotRole = false;
@@ -440,9 +440,6 @@ public async Task TokenCutoff(OpenAIClientServiceTarget serviceTarget)
[TestCase(OpenAIClientServiceTarget.NonAzure)]
public async Task StreamingCompletions(OpenAIClientServiceTarget serviceTarget)
{
- // Temporary test note: at the time of authoring, content filter results aren't included for completions
- // with the latest 2023-07-01-preview service API version. We'll manually configure one version behind
- // pending the latest version having these results enabled.
OpenAIClient client = GetTestClient(serviceTarget);
string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.LegacyCompletions);
@@ -467,51 +464,68 @@ public async Task StreamingCompletions(OpenAIClientServiceTarget serviceTarget)
// implement IDisposable or otherwise ensure that an `IDisposable` underlying `.Value` is disposed.
using StreamingCompletions responseValue = response.Value;
- int originallyEnumeratedChoices = 0;
- var originallyEnumeratedTextParts = new List>();
+ Dictionary promptFilterResultsByPromptIndex = new();
+ Dictionary finishReasonsByChoiceIndex = new();
+ Dictionary textBuildersByChoiceIndex = new();
- await foreach (StreamingChoice choice in responseValue.GetChoicesStreaming())
+ await foreach (Completions streamingCompletions in responseValue.EnumerateCompletions())
{
- List textPartsForChoice = new();
- StringBuilder choiceTextBuilder = new();
- await foreach (string choiceTextPart in choice.GetTextStreaming())
+ if (streamingCompletions.PromptFilterResults is not null)
{
- choiceTextBuilder.Append(choiceTextPart);
- textPartsForChoice.Add(choiceTextPart);
+ foreach (PromptFilterResult promptFilterResult in streamingCompletions.PromptFilterResults)
+ {
+ // When providing multiple prompts, the filter results may arrive across separate messages and
+ // the payload array index may differ from the in-data index property
+ AssertExpectedContentFilterResults(promptFilterResult.ContentFilterResults, serviceTarget);
+ promptFilterResultsByPromptIndex[promptFilterResult.PromptIndex] = promptFilterResult;
+ }
+ }
+ foreach (Choice choice in streamingCompletions.Choices)
+ {
+ if (choice.FinishReason.HasValue)
+ {
+ // Each choice should only receive a single finish reason and, in this case, it should be
+ // 'stop'
+ Assert.That(!finishReasonsByChoiceIndex.ContainsKey(choice.Index));
+ Assert.That(choice.FinishReason.Value == CompletionsFinishReason.Stopped);
+ finishReasonsByChoiceIndex[choice.Index] = choice.FinishReason.Value;
+ }
+
+ // The 'Text' property of streamed Completions will only contain the incremental update for a
+ // choice; these should be appended to each other to form the complete text
+ if (!textBuildersByChoiceIndex.ContainsKey(choice.Index))
+ {
+ textBuildersByChoiceIndex[choice.Index] = new();
+ }
+ textBuildersByChoiceIndex[choice.Index].Append(choice.Text);
+
+ Assert.That(choice.LogProbabilityModel, Is.Not.Null);
+
+ if (!string.IsNullOrEmpty(choice.Text))
+ {
+ // Content filter results are only populated when content (text) is present
+ AssertExpectedContentFilterResults(choice.ContentFilterResults, serviceTarget);
+ }
}
- Assert.That(choiceTextBuilder.ToString(), Is.Not.Null.Or.Empty);
- Assert.That(choice.LogProbabilityModel, Is.Not.Null);
- AssertExpectedContentFilterResults(choice.ContentFilterResults, serviceTarget);
- originallyEnumeratedChoices++;
- originallyEnumeratedTextParts.Add(textPartsForChoice);
}
- // Note: these top-level values *are likely not yet populated* until *after* at least one streaming
- // choice has arrived.
- Assert.That(response.GetRawResponse(), Is.Not.Null.Or.Empty);
- Assert.That(responseValue.Id, Is.Not.Null.Or.Empty);
- Assert.That(responseValue.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1))));
- Assert.That(responseValue.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7)));
+ int expectedPromptCount = requestOptions.Prompts.Count;
+ int expectedChoiceCount = expectedPromptCount * (requestOptions.ChoicesPerPrompt ?? 1);
- AssertExpectedPromptFilterResults(
- responseValue.PromptFilterResults,
- serviceTarget,
- expectedCount: requestOptions.Prompts.Count * (requestOptions.ChoicesPerPrompt ?? 1));
-
- // Validate stability of enumeration (non-cancelled case)
- IReadOnlyList secondPassChoices = await GetBlockingListFromIAsyncEnumerable(
- responseValue.GetChoicesStreaming());
- Assert.AreEqual(originallyEnumeratedChoices, secondPassChoices.Count);
- for (int i = 0; i < secondPassChoices.Count; i++)
+ if (serviceTarget == OpenAIClientServiceTarget.Azure)
{
- IReadOnlyList secondPassTextParts = await GetBlockingListFromIAsyncEnumerable(
- secondPassChoices[i].GetTextStreaming());
- Assert.AreEqual(originallyEnumeratedTextParts[i].Count, secondPassTextParts.Count);
- for (int j = 0; j < originallyEnumeratedTextParts[i].Count; j++)
+ for (int i = 0; i < expectedPromptCount; i++)
{
- Assert.AreEqual(originallyEnumeratedTextParts[i][j], secondPassTextParts[j]);
+ Assert.That(promptFilterResultsByPromptIndex.ContainsKey(i));
}
}
+
+ for (int i = 0; i < expectedChoiceCount; i++)
+ {
+ Assert.That(finishReasonsByChoiceIndex.ContainsKey(i));
+ Assert.That(textBuildersByChoiceIndex.ContainsKey(i));
+ Assert.That(textBuildersByChoiceIndex[i].ToString(), Is.Not.Null.Or.Empty);
+ }
}
[RecordedTest]
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs
index bd0ee5d1db4e1..8bee03f306da5 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample04_StreamingChat.cs
@@ -11,39 +11,6 @@ namespace Azure.AI.OpenAI.Tests.Samples
{
public partial class StreamingChat
{
- //[Test]
- //[Ignore("Only verifying that the sample builds")]
- //public async Task StreamingChatWithNonAzureOpenAI()
- //{
- // #region Snippet:StreamChatMessages
- // string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com";
- // var client = new OpenAIClient(nonAzureOpenAIApiKey, new OpenAIClientOptions());
- // var chatCompletionsOptions = new ChatCompletionsOptions()
- // {
- // Messages =
- // {
- // new ChatMessage(ChatRole.System, "You are a helpful assistant. You will talk like a pirate."),
- // new ChatMessage(ChatRole.User, "Can you help me?"),
- // new ChatMessage(ChatRole.Assistant, "Arrrr! Of course, me hearty! What can I do for ye?"),
- // new ChatMessage(ChatRole.User, "What's the best way to train a parrot?"),
- // }
- // };
-
- // Response response = await client.GetChatCompletionsStreamingAsync(
- // deploymentOrModelName: "gpt-3.5-turbo",
- // chatCompletionsOptions);
- // using StreamingChatCompletions streamingChatCompletions = response.Value;
-
- // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
- // {
- // await foreach (ChatMessage message in choice.GetMessageStreaming())
- // {
- // Console.Write(message.Content);
- // }
- // Console.WriteLine();
- // }
- // #endregion
- //}
[Test]
[Ignore("Only verifying that the sample builds")]
public async Task StreamingChatWithNonAzureOpenAI()
@@ -62,10 +29,10 @@ public async Task StreamingChatWithNonAzureOpenAI()
}
};
- Response response = await client.GetChatCompletionsStreamingAsync2(
+ Response response = await client.GetChatCompletionsStreamingAsync(
deploymentOrModelName: "gpt-3.5-turbo",
chatCompletionsOptions);
- using StreamingChatCompletions2 streamingChatCompletions = response.Value;
+ using StreamingChatCompletions streamingChatCompletions = response.Value;
await foreach (StreamingChatCompletionsUpdate chatUpdate
in streamingChatCompletions.EnumerateChatUpdates())
From 9d2b330031ffac79d01c53bae330bfc5277546eb Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Wed, 18 Oct 2023 17:13:27 -0700
Subject: [PATCH 05/19] comments, tests, and other cleanup
---
.../src/Custom/StreamingChatCompletions.cs | 54 +++++++-
.../Custom/StreamingChatCompletionsUpdate.cs | 120 +++++++++++++++++-
.../src/Custom/StreamingCompletions.cs | 61 ++++++++-
.../tests/ChatFunctionsTests.cs | 53 +-------
.../OpenAIInferenceModelFactoryTests.cs.cs | 34 +++++
.../tests/OpenAIInferenceTests.cs | 52 +-------
6 files changed, 255 insertions(+), 119 deletions(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
index fe868d29f6e60..4da4215ef7751 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
@@ -12,13 +12,24 @@
namespace Azure.AI.OpenAI
{
+ ///
+ /// Represents a streaming source of instances that allow an
+ /// application to incrementally use and display text and other information as it becomes available during a
+ /// generation operation.
+ ///
public class StreamingChatCompletions : IDisposable
{
private readonly Response _baseResponse;
private readonly SseReader _baseResponseReader;
- private List _cachedUpdates;
+ private readonly List _cachedUpdates;
private bool _disposedValue;
+ ///
+ /// Creates a new instance of based on a REST response stream.
+ ///
+ ///
+ /// The from which server-sent events will be consumed.
+ ///
internal StreamingChatCompletions(Response response)
{
_baseResponse = response;
@@ -26,11 +37,42 @@ internal StreamingChatCompletions(Response response)
_cachedUpdates = new();
}
+ ///
+ /// Creates a new instance of based on an existing set of
+ /// instances. Typically used for tests or mocking.
+ ///
+ ///
+ /// An existing collection of instances to use. Update enumeration
+ /// will yield these instances instead of instances derived from a response stream.
+ ///
internal StreamingChatCompletions(IEnumerable updates)
{
- _cachedUpdates.AddRange(updates);
+ _cachedUpdates = new(updates);
}
+ ///
+ /// Returns an asynchronous enumeration of instances as they
+ /// become available via the associated network response or other data source.
+ ///
+ ///
+ /// This collection may be iterated over using the "await foreach" statement.
+ ///
+ ///
+ ///
+ /// An optional used to end enumeration before it would normally complete.
+ ///
+ ///
+ /// Cancellation will immediately close any underlying network response stream and may consequently limit
+ /// incurred token generation and associated cost.
+ ///
+ ///
+ ///
+ /// An asynchronous enumeration of instances.
+ ///
+ ///
+ /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not
+ /// prefixed with the expected 'data:' tag.
+ ///
public async IAsyncEnumerable EnumerateChatUpdates(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
@@ -52,7 +94,7 @@ public async IAsyncEnumerable EnumerateChatUpdat
SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
if (sseEvent == null)
{
- _baseResponse.ContentStream?.Dispose();
+ _baseResponse?.ContentStream?.Dispose();
break;
}
@@ -63,7 +105,7 @@ public async IAsyncEnumerable EnumerateChatUpdat
ReadOnlyMemory value = sseEvent.Value.FieldValue;
if (value.Span.SequenceEqual("[DONE]".AsSpan()))
{
- _baseResponse.ContentStream?.Dispose();
+ _baseResponse?.ContentStream?.Dispose();
break;
}
@@ -78,19 +120,21 @@ in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sse
_baseResponse?.ContentStream?.Dispose();
}
+ ///
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
+ ///
protected virtual void Dispose(bool disposing)
{
if (!_disposedValue)
{
if (disposing)
{
- _baseResponseReader.Dispose();
+ _baseResponseReader?.Dispose();
}
_disposedValue = true;
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
index c16e0d20f2e5d..f28ddcf5f94fd 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
@@ -2,34 +2,144 @@
// Licensed under the MIT License.
using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Threading;
-using Azure.Core;
namespace Azure.AI.OpenAI
{
public partial class StreamingChatCompletionsUpdate
{
+ ///
+ /// Gets a unique identifier associated with this streamed Chat Completions response.
+ ///
+ ///
+ ///
+ /// Corresponds to $.id in the underlying REST schema.
+ ///
+ /// When using Azure OpenAI, note that the values of and may not be
+ /// populated until the first containing role, content, or
+ /// function information.
+ ///
public string Id { get; }
+ ///
+ /// Gets the first timestamp associated with generation activity for this completions response,
+ /// represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970.
+ ///
+ ///
+ ///
+ /// Corresponds to $.created in the underlying REST schema.
+ ///
+ /// When using Azure OpenAI, note that the values of and may not be
+ /// populated until the first containing role, content, or
+ /// function information.
+ ///
public DateTimeOffset Created { get; }
+ ///
+ /// Gets the associated with this update.
+ ///
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].delta.role in the underlying REST schema.
+ ///
+ /// assignment typically occurs in a single update across a streamed Chat Completions
+ /// choice and the value should be considered to persist for all subsequent updates without a
+ /// that bear the same .
+ ///
public ChatRole? Role { get; }
+ ///
+ /// Gets the content fragment associated with this update.
+ ///
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].delta.content in the underlying REST schema.
+ ///
+ /// Each update contains only a small number of tokens. When presenting or reconstituting a full, streamed
+ /// response, all values for the same should be
+ /// combined.
+ ///
public string ContentUpdate { get; }
+ ///
+ /// Gets the name of a function to be called.
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].delta.function_call.name in the underlying REST schema.
+ ///
public string FunctionName { get; }
+ ///
+ /// Gets a function arguments fragment associated with this update.
+ ///
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].delta.function_call.arguments in the underlying REST schema.
+ ///
+ ///
+ ///
+ /// Each update contains only a small number of tokens. When presenting or reconstituting a full, streamed
+ /// arguments body, all values for the same
+ /// should be combined.
+ ///
+ ///
+ ///
+ /// As is the case for non-streaming , the content provided for function
+ /// arguments is not guaranteed to be well-formed JSON or to contain expected data. Use of function arguments
+ /// should validate before consumption.
+ ///
+ ///
public string FunctionArgumentsUpdate { get; }
+ ///
+ /// Gets an optional name associated with the role of the streamed Chat Completion, typically as previously
+ /// specified in a system message.
+ ///
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].delta.name in the underlying REST schema.
+ ///
+ ///
public string AuthorName { get; }
+ ///
+ /// Gets the associated with this update.
+ ///
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].finish_reason in the underlying REST schema.
+ ///
+ ///
+ /// assignment typically appears in the final streamed update message associated
+ /// with a choice.
+ ///
+ ///
public CompletionsFinishReason? FinishReason { get; }
+ ///
+ /// Gets the choice index associated with this streamed update.
+ ///
+ ///
+ ///
+ /// Corresponds to e.g. $.choices[0].index in the underlying REST schema.
+ ///
+ ///
+ /// Unless a value greater than one was provided to ('n' in
+ /// REST), only one choice will be generated. In that case, this value will always be 0 and may not need to be
+ /// considered.
+ ///
+ ///
+ /// When providing a value greater than one to , this index
+ /// identifies which logical response the payload is associated with. In the event that a single underlying
+ /// server-sent event contains multiple choices, multiple instances of
+ /// will be created.
+ ///
+ ///
public int? ChoiceIndex { get; }
+ ///
+ /// Gets additional data associated with this update that is specific to use with Azure OpenAI and that
+ /// service's extended capabilities.
+ ///
public AzureChatExtensionsMessageContext AzureExtensionsContext { get; }
internal StreamingChatCompletionsUpdate(
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
index fa032a2f28f50..ca1e98923f136 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
@@ -8,18 +8,28 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
-using System.Threading.Tasks;
using Azure.Core.Sse;
namespace Azure.AI.OpenAI
{
+ ///
+ /// Represents a streaming source of instances that allow an application to incrementally
+ /// use and display text and other information as it becomes available during a generation operation.
+ ///
public class StreamingCompletions : IDisposable
{
private readonly Response _baseResponse;
private readonly SseReader _baseResponseReader;
- private List _cachedCompletions;
+ private readonly List _cachedCompletions;
private bool _disposedValue;
+ ///
+ /// Creates a new instance of that will generate updates from a
+ /// service response stream.
+ ///
+ ///
+ /// The instance from which server-sent events will be parsed.
+ ///
internal StreamingCompletions(Response response)
{
_baseResponse = response;
@@ -27,11 +37,50 @@ internal StreamingCompletions(Response response)
_cachedCompletions = new();
}
+ ///
+ /// Creates a new instance of that will yield instances from a provided set
+ /// of instances. Typically used for tests and mocking.
+ ///
+ ///
+ /// An existing set of instances to use.
+ ///
internal StreamingCompletions(IEnumerable completions)
{
- _cachedCompletions.AddRange(completions);
+ _cachedCompletions = new(completions);
}
+ ///
+ /// Returns an asynchronous enumeration of instances as they
+ /// become available via the associated network response or other data source.
+ ///
+ ///
+ ///
+ /// This collection may be iterated over using the "await foreach" statement.
+ ///
+ ///
+ /// Note that, in contrast to , streamed share
+ /// the same underlying representation and object model as their non-streamed counterpart. The
+ /// property on a streamed instance will only contain a
+ /// small number of new tokens and must be combined with other values to represent
+ /// a full, reconstituted message.
+ ///
+ ///
+ ///
+ ///
+ /// An optional used to end enumeration before it would normally complete.
+ ///
+ ///
+ /// Cancellation will immediately close any underlying network response stream and may consequently limit
+ /// incurred token generation and associated cost.
+ ///
+ ///
+ ///
+ /// An asynchronous enumeration of instances.
+ ///
+ ///
+ /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not
+ /// prefixed with the expected 'data:' tag.
+ ///
public async IAsyncEnumerable EnumerateCompletions(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
@@ -53,7 +102,7 @@ public async IAsyncEnumerable EnumerateCompletions(
SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
if (sseEvent == null)
{
- _baseResponse.ContentStream?.Dispose();
+ _baseResponse?.ContentStream?.Dispose();
break;
}
@@ -64,7 +113,7 @@ public async IAsyncEnumerable EnumerateCompletions(
ReadOnlyMemory value = sseEvent.Value.FieldValue;
if (value.Span.SequenceEqual("[DONE]".AsSpan()))
{
- _baseResponse.ContentStream?.Dispose();
+ _baseResponse?.ContentStream?.Dispose();
break;
}
@@ -88,7 +137,7 @@ protected virtual void Dispose(bool disposing)
{
if (disposing)
{
- _baseResponseReader.Dispose();
+ _baseResponseReader?.Dispose();
}
_disposedValue = true;
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
index 1f3a2d72ffb3b..fd7b345ef854a 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatFunctionsTests.cs
@@ -84,61 +84,10 @@ public async Task SimpleFunctionCallWorks(OpenAIClientServiceTarget serviceTarge
Assert.That(followupResponse.Value.Choices[0].Message.Content, Is.Not.Null.Or.Empty);
}
- //[RecordedTest]
- //[TestCase(OpenAIClientServiceTarget.Azure)]
- //[TestCase(OpenAIClientServiceTarget.NonAzure)]
- //public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget)
- //{
- // OpenAIClient client = GetTestClient(serviceTarget);
- // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions);
-
- // var requestOptions = new ChatCompletionsOptions()
- // {
- // Functions = { s_futureTemperatureFunction },
- // Messages =
- // {
- // new ChatMessage(ChatRole.System, "You are a helpful assistant."),
- // new ChatMessage(ChatRole.User, "What should I wear in Honolulu next Thursday?"),
- // },
- // MaxTokens = 512,
- // };
-
- // Response response
- // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
- // Assert.That(response, Is.Not.Null);
-
- // using StreamingChatCompletions streamingChatCompletions = response.Value;
-
- // ChatRole streamedRole = default;
- // string functionName = default;
- // StringBuilder argumentsBuilder = new();
-
- // await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoicesStreaming())
- // {
- // await foreach (ChatMessage message in choice.GetMessageStreaming())
- // {
- // if (message.Role != default)
- // {
- // streamedRole = message.Role;
- // }
- // if (message.FunctionCall?.Name != null)
- // {
- // Assert.That(functionName, Is.Null.Or.Empty);
- // functionName = message.FunctionCall.Name;
- // }
- // argumentsBuilder.Append(message.FunctionCall?.Arguments ?? string.Empty);
- // }
- // }
-
- // Assert.That(streamedRole, Is.EqualTo(ChatRole.Assistant));
- // Assert.That(functionName, Is.EqualTo(s_futureTemperatureFunction.Name));
- // Assert.That(argumentsBuilder.Length, Is.GreaterThan(0));
- //}
-
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
[TestCase(OpenAIClientServiceTarget.NonAzure)]
- public async Task StreamingFunctionCallWorks2(OpenAIClientServiceTarget serviceTarget)
+ public async Task StreamingFunctionCallWorks(OpenAIClientServiceTarget serviceTarget)
{
OpenAIClient client = GetTestClient(serviceTarget);
string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(serviceTarget, OpenAIClientScenario.ChatCompletions);
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs
index ea9e18667a16a..30401cad8ce50 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceModelFactoryTests.cs.cs
@@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
using NUnit.Framework;
namespace Azure.AI.OpenAI.Tests
@@ -180,5 +182,37 @@ public void TestAudioTranslation()
Assert.That(audioTranslation.Segments, Is.Not.Null.Or.Empty);
Assert.That(audioTranslation.Segments.Count, Is.EqualTo(1));
}
+
+ [Test]
+ public async Task TestStreamingChatCompletions()
+ {
+ const string expectedId = "expected-id-value";
+
+ StreamingChatCompletionsUpdate firstUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate(
+ expectedId,
+ DateTime.Now,
+ role: ChatRole.Assistant,
+ contentUpdate: "hello");
+ StreamingChatCompletionsUpdate secondUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate(
+ expectedId,
+ DateTime.Now,
+ contentUpdate: " world");
+ StreamingChatCompletionsUpdate thirdUpdate = AzureOpenAIModelFactory.StreamingChatCompletionsUpdate(
+ expectedId,
+ DateTime.Now,
+ finishReason: CompletionsFinishReason.Stopped);
+
+ using StreamingChatCompletions streamingChatCompletions = AzureOpenAIModelFactory.StreamingChatCompletions(
+ new[] { firstUpdate, secondUpdate, thirdUpdate });
+
+ StringBuilder contentBuilder = new();
+ await foreach (StreamingChatCompletionsUpdate update in streamingChatCompletions.EnumerateChatUpdates())
+ {
+ Assert.That(update.Id == expectedId);
+ Assert.That(update.Created > new DateTimeOffset(new DateTime(2023, 1, 1)));
+ contentBuilder.Append(update.ContentUpdate);
+ }
+ Assert.That(contentBuilder.ToString() == "hello world");
+ }
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
index e5eabcaa1a50c..951580c9d7c3f 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
@@ -223,60 +223,10 @@ string deploymentOrModelName
AssertExpectedContentFilterResults(firstChoice.ContentFilterResults, serviceTarget);
}
- //[RecordedTest]
- //[TestCase(OpenAIClientServiceTarget.Azure)]
- //[TestCase(OpenAIClientServiceTarget.NonAzure)]
- //public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget)
- //{
- // OpenAIClient client = GetTestClient(serviceTarget);
- // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
- // serviceTarget,
- // OpenAIClientScenario.ChatCompletions);
- // var requestOptions = new ChatCompletionsOptions()
- // {
- // Messages =
- // {
- // new ChatMessage(ChatRole.System, "You are a helpful assistant."),
- // new ChatMessage(ChatRole.User, "Can you help me?"),
- // new ChatMessage(ChatRole.Assistant, "Of course! What do you need help with?"),
- // new ChatMessage(ChatRole.User, "What temperature should I bake pizza at?"),
- // },
- // MaxTokens = 512,
- // };
- // Response streamingResponse
- // = await client.GetChatCompletionsStreamingAsync(deploymentOrModelName, requestOptions);
- // Assert.That(streamingResponse, Is.Not.Null);
- // using StreamingChatCompletions streamingChatCompletions = streamingResponse.Value;
- // Assert.That(streamingChatCompletions, Is.InstanceOf());
-
- // int totalMessages = 0;
-
- // await foreach (StreamingChatChoice streamingChoice in streamingChatCompletions.GetChoicesStreaming())
- // {
- // Assert.That(streamingChoice, Is.Not.Null);
- // await foreach (ChatMessage streamingMessage in streamingChoice.GetMessageStreaming())
- // {
- // Assert.That(streamingMessage.Role, Is.EqualTo(ChatRole.Assistant));
- // totalMessages++;
- // }
- // AssertExpectedContentFilterResults(streamingChoice.ContentFilterResults, serviceTarget);
- // }
-
- // Assert.That(totalMessages, Is.GreaterThan(1));
-
- // // Note: these top-level values *are likely not yet populated* until *after* at least one streaming
- // // choice has arrived.
- // Assert.That(streamingResponse.GetRawResponse(), Is.Not.Null.Or.Empty);
- // Assert.That(streamingChatCompletions.Id, Is.Not.Null.Or.Empty);
- // Assert.That(streamingChatCompletions.Created, Is.GreaterThan(new DateTimeOffset(new DateTime(2023, 1, 1))));
- // Assert.That(streamingChatCompletions.Created, Is.LessThan(DateTimeOffset.UtcNow.AddDays(7)));
- // AssertExpectedPromptFilterResults(streamingChatCompletions.PromptFilterResults, serviceTarget, (requestOptions.ChoiceCount ?? 1));
- //}
-
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
[TestCase(OpenAIClientServiceTarget.NonAzure)]
- public async Task StreamingChatCompletions2(OpenAIClientServiceTarget serviceTarget)
+ public async Task StreamingChatCompletions(OpenAIClientServiceTarget serviceTarget)
{
OpenAIClient client = GetTestClient(serviceTarget);
string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
From 2e3f43fe8dea77730ded99bbbfd50620d2f6b652 Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Wed, 18 Oct 2023 17:15:59 -0700
Subject: [PATCH 06/19] one orphaned test comment cleanup
---
.../tests/AzureChatExtensionsTests.cs | 62 -------------------
1 file changed, 62 deletions(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
index c72a2abcfdead..96b1c56c7bff4 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/AzureChatExtensionsTests.cs
@@ -89,68 +89,6 @@ public async Task BasicSearchExtensionWorks(
Assert.That(context.Messages.First().Content.Contains("citations"));
}
- //[RecordedTest]
- //[TestCase(OpenAIClientServiceTarget.Azure)]
- //public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget)
- //{
- // OpenAIClient client = GetTestClient(serviceTarget);
- // string deploymentOrModelName = OpenAITestBase.GetDeploymentOrModelName(
- // serviceTarget,
- // OpenAIClientScenario.ChatCompletions);
-
- // var requestOptions = new ChatCompletionsOptions()
- // {
- // Messages =
- // {
- // new ChatMessage(ChatRole.User, "What does PR complete mean?"),
- // },
- // MaxTokens = 512,
- // AzureExtensionsOptions = new()
- // {
- // Extensions =
- // {
- // new AzureChatExtensionConfiguration()
- // {
- // Type = "AzureCognitiveSearch",
- // Parameters = BinaryData.FromObjectAsJson(new
- // {
- // Endpoint = "https://openaisdktestsearch.search.windows.net",
- // IndexName = "openai-test-index-carbon-wiki",
- // Key = GetCognitiveSearchApiKey().Key,
- // },
- // new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }),
- // },
- // }
- // },
- // };
-
- // Response response = await client.GetChatCompletionsStreamingAsync(
- // deploymentOrModelName,
- // requestOptions);
- // Assert.That(response, Is.Not.Null);
- // Assert.That(response.Value, Is.Not.Null);
-
- // using StreamingChatCompletions streamingChatCompletions = response.Value;
-
- // int choiceCount = 0;
- // List messageChunks = new();
-
- // await foreach (StreamingChatChoice streamingChatChoice in response.Value.GetChoicesStreaming())
- // {
- // choiceCount++;
- // await foreach (ChatMessage chatMessage in streamingChatChoice.GetMessageStreaming())
- // {
- // messageChunks.Add(chatMessage);
- // }
- // }
-
- // Assert.That(choiceCount, Is.EqualTo(1));
- // Assert.That(messageChunks, Is.Not.Null.Or.Empty);
- // Assert.That(messageChunks.Any(chunk => chunk.AzureExtensionsContext != null && chunk.AzureExtensionsContext.Messages.Any(m => m.Role == ChatRole.Tool)));
- // //Assert.That(messageChunks.Any(chunk => chunk.Role == ChatRole.Assistant));
- // Assert.That(messageChunks.Any(chunk => !string.IsNullOrWhiteSpace(chunk.Content)));
- //}
-
[RecordedTest]
[TestCase(OpenAIClientServiceTarget.Azure)]
public async Task StreamingSearchExtensionWorks(OpenAIClientServiceTarget serviceTarget)
From 91333589f303a9ff24775dfcc052f60d6b7b2b46 Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Wed, 18 Oct 2023 17:33:53 -0700
Subject: [PATCH 07/19] xml comment fix
---
.../Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs | 1 -
sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs | 1 +
2 files changed, 1 insertion(+), 1 deletion(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
index 4da4215ef7751..ff6a184e318e3 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
@@ -127,7 +127,6 @@ public void Dispose()
GC.SuppressFinalize(this);
}
- ///
protected virtual void Dispose(bool disposing)
{
if (!_disposedValue)
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
index ca1e98923f136..a26baeb5631a7 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
@@ -125,6 +125,7 @@ public async IAsyncEnumerable EnumerateCompletions(
_baseResponse?.ContentStream?.Dispose();
}
+ ///
public void Dispose()
{
Dispose(disposing: true);
From 6b520eafa509035d64a059bdc539576768e1132e Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Thu, 19 Oct 2023 12:14:47 -0700
Subject: [PATCH 08/19] test assets, making it real
---
sdk/openai/Azure.AI.OpenAI/assets.json | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/assets.json b/sdk/openai/Azure.AI.OpenAI/assets.json
index 9b85d4de7bc32..3aea2fc197d1d 100644
--- a/sdk/openai/Azure.AI.OpenAI/assets.json
+++ b/sdk/openai/Azure.AI.OpenAI/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "net",
"TagPrefix": "net/openai/Azure.AI.OpenAI",
- "Tag": "net/openai/Azure.AI.OpenAI_52e82965d8"
+ "Tag": "net/openai/Azure.AI.OpenAI_2ea3f2ca28"
}
From 5694c900f8341cf0f02c20c85eb768888ed11491 Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Thu, 19 Oct 2023 13:06:41 -0700
Subject: [PATCH 09/19] test assets pt. 2 (re-record functions)
---
sdk/openai/Azure.AI.OpenAI/assets.json | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sdk/openai/Azure.AI.OpenAI/assets.json b/sdk/openai/Azure.AI.OpenAI/assets.json
index 3aea2fc197d1d..3a4237461fe9c 100644
--- a/sdk/openai/Azure.AI.OpenAI/assets.json
+++ b/sdk/openai/Azure.AI.OpenAI/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "net",
"TagPrefix": "net/openai/Azure.AI.OpenAI",
- "Tag": "net/openai/Azure.AI.OpenAI_2ea3f2ca28"
+ "Tag": "net/openai/Azure.AI.OpenAI_86a53abc90"
}
From e1b24f646a893bebeeba3f29ef5da64c576da060 Mon Sep 17 00:00:00 2001
From: Travis Wilson
Date: Fri, 20 Oct 2023 14:44:09 -0700
Subject: [PATCH 10/19] revised pattern using new StreamingResponse
---
sdk/openai/Azure.AI.OpenAI/README.md | 8 +-
.../api/Azure.AI.OpenAI.netstandard2.0.cs | 36 ++---
.../src/Custom/AzureOpenAIModelFactory.cs | 29 +---
.../src/Custom/OpenAIClient.cs | 63 ++++++--
.../src/Custom/StreamingChatCompletions.cs | 143 -----------------
.../Custom/StreamingChatCompletionsUpdate.cs | 13 +-
.../src/Custom/StreamingCompletions.cs | 148 ------------------
.../src/Helpers/SseAsyncEnumerator.cs | 63 ++++++++
.../Azure.AI.OpenAI/src/Helpers/SseReader.cs | 8 +-
.../src/Helpers/StreamingResponse.cs | 70 +++++++++
.../tests/AzureChatExtensionsTests.cs | 15 +-
.../tests/ChatFunctionsTests.cs | 8 +-
.../OpenAIInferenceModelFactoryTests.cs.cs | 41 +++--
.../tests/OpenAIInferenceTests.cs | 31 +---
.../tests/Samples/Sample04_StreamingChat.cs | 10 +-
15 files changed, 256 insertions(+), 430 deletions(-)
delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
delete mode 100644 sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingCompletions.cs
create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Helpers/SseAsyncEnumerator.cs
create mode 100644 sdk/openai/Azure.AI.OpenAI/src/Helpers/StreamingResponse.cs
diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md
index cd8a7d002be92..0fd3e91de72be 100644
--- a/sdk/openai/Azure.AI.OpenAI/README.md
+++ b/sdk/openai/Azure.AI.OpenAI/README.md
@@ -206,13 +206,9 @@ var chatCompletionsOptions = new ChatCompletionsOptions()
}
};
-Response response = await client.GetChatCompletionsStreamingAsync(
+await foreach (StreamingChatCompletionsUpdate chatUpdate in client.GetChatCompletionsStreaming(
deploymentOrModelName: "gpt-3.5-turbo",
- chatCompletionsOptions);
-using StreamingChatCompletions streamingChatCompletions = response.Value;
-
-await foreach (StreamingChatCompletionsUpdate chatUpdate
- in streamingChatCompletions.EnumerateChatUpdates())
+ chatCompletionsOptions))
{
if (chatUpdate.Role.HasValue)
{
diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
index 7a2dfac1a77bc..c74ba62e240da 100644
--- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
+++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs
@@ -1,3 +1,15 @@
+namespace Azure
+{
+ public partial class StreamingResponse : System.Collections.Generic.IAsyncEnumerable, System.IDisposable
+ {
+ internal StreamingResponse() { }
+ public void Dispose() { }
+ protected virtual void Dispose(bool disposing) { }
+ public System.Collections.Generic.IAsyncEnumerable EnumerateValues() { throw null; }
+ public Azure.Response GetRawResponse() { throw null; }
+ System.Collections.Generic.IAsyncEnumerator System.Collections.Generic.IAsyncEnumerable.GetAsyncEnumerator(System.Threading.CancellationToken cancellationToken) { throw null; }
+ }
+}
namespace Azure.AI.OpenAI
{
public partial class AudioTranscription
@@ -206,9 +218,7 @@ public static partial class AzureOpenAIModelFactory
public static Azure.AI.OpenAI.ImageGenerations ImageGenerations(System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable data = null) { throw null; }
public static Azure.AI.OpenAI.ImageLocation ImageLocation(System.Uri url = null) { throw null; }
public static Azure.AI.OpenAI.PromptFilterResult PromptFilterResult(int promptIndex = 0, Azure.AI.OpenAI.ContentFilterResults contentFilterResults = null) { throw null; }
- public static Azure.AI.OpenAI.StreamingChatCompletions StreamingChatCompletions(System.Collections.Generic.IEnumerable updates) { throw null; }
public static Azure.AI.OpenAI.StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(string id, System.DateTimeOffset created, int? choiceIndex = default(int?), Azure.AI.OpenAI.ChatRole? role = default(Azure.AI.OpenAI.ChatRole?), string authorName = null, string contentUpdate = null, Azure.AI.OpenAI.CompletionsFinishReason? finishReason = default(Azure.AI.OpenAI.CompletionsFinishReason?), string functionName = null, string functionArgumentsUpdate = null, Azure.AI.OpenAI.AzureChatExtensionsMessageContext azureExtensionsContext = null) { throw null; }
- public static Azure.AI.OpenAI.StreamingCompletions StreamingCompletions(System.Collections.Generic.IEnumerable completions) { throw null; }
}
public partial class ChatChoice
{
@@ -479,14 +489,14 @@ public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential tokenCredent
public virtual System.Threading.Tasks.Task> GetAudioTranslationAsync(string deploymentId, Azure.AI.OpenAI.AudioTranslationOptions audioTranslationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetChatCompletions(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetChatCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- public virtual Azure.Response GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public virtual Azure.StreamingResponse GetChatCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public virtual System.Threading.Tasks.Task> GetChatCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.ChatCompletionsOptions chatCompletionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetCompletions(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetCompletions(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetCompletionsAsync(string deploymentOrModelName, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- public virtual Azure.Response GetCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public virtual Azure.StreamingResponse GetCompletionsStreaming(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public virtual System.Threading.Tasks.Task> GetCompletionsStreamingAsync(string deploymentOrModelName, Azure.AI.OpenAI.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetEmbeddings(string deploymentOrModelName, Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task> GetEmbeddingsAsync(string deploymentOrModelName, Azure.AI.OpenAI.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetImageGenerations(Azure.AI.OpenAI.ImageGenerationOptions imageGenerationOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
@@ -511,13 +521,6 @@ internal PromptFilterResult() { }
public Azure.AI.OpenAI.ContentFilterResults ContentFilterResults { get { throw null; } }
public int PromptIndex { get { throw null; } }
}
- public partial class StreamingChatCompletions : System.IDisposable
- {
- internal StreamingChatCompletions() { }
- public void Dispose() { }
- protected virtual void Dispose(bool disposing) { }
- public System.Collections.Generic.IAsyncEnumerable EnumerateChatUpdates([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- }
public partial class StreamingChatCompletionsUpdate
{
internal StreamingChatCompletionsUpdate() { }
@@ -532,13 +535,6 @@ internal StreamingChatCompletionsUpdate() { }
public string Id { get { throw null; } }
public Azure.AI.OpenAI.ChatRole? Role { get { throw null; } }
}
- public partial class StreamingCompletions : System.IDisposable
- {
- internal StreamingCompletions() { }
- public void Dispose() { }
- protected virtual void Dispose(bool disposing) { }
- public System.Collections.Generic.IAsyncEnumerable EnumerateCompletions([System.Runtime.CompilerServices.EnumeratorCancellationAttribute] System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
- }
}
namespace Microsoft.Extensions.Azure
{
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs
index cc90ea19685df..9cdbf2b7d2767 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIModelFactory.cs
@@ -14,11 +14,11 @@ namespace Azure.AI.OpenAI
public static partial class AzureOpenAIModelFactory
{
/// Initializes a new instance of ChatChoice.
- /// The chat message associated with this chat completions choice
+ /// The chat message associated with this chat completions choice.
/// The ordered index associated with this chat completions choice.
/// The reason that this chat completions choice completed its generated.
- /// For streamed choices, the internal representation of a 'delta' payload
- /// The category annotations for this chat choice's content filtering
+ /// For streamed choices, the internal representation of a 'delta' payload.
+ /// The category annotations for this chat choice's content filtering.
/// A new instance for mocking.
public static ChatChoice ChatChoice(
ChatMessage message = null,
@@ -30,16 +30,6 @@ public static ChatChoice ChatChoice(
return new ChatChoice(message, index, finishReason, deltaMessage, contentFilterResults);
}
- ///
- /// Initializes a new instance of StreamingCompletions for tests and mocking.
- ///
- /// A static set of Completions instances to use for tests and mocking.
- /// A new instance of StreamingCompletions.
- public static StreamingCompletions StreamingCompletions(IEnumerable completions)
- {
- return new StreamingCompletions(completions);
- }
-
public static StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(
string id,
DateTimeOffset created,
@@ -65,19 +55,6 @@ public static StreamingChatCompletionsUpdate StreamingChatCompletionsUpdate(
azureExtensionsContext);
}
- ///
- /// Initializes a new instance of StreamingChatCompletions for tests and mocking.
- ///
- ///
- /// A static collection of StreamingChatCompletionUpdates to use for tests and mocking.
- ///
- /// A new instance of StreamingChatCompletions
- public static StreamingChatCompletions StreamingChatCompletions(
- IEnumerable updates)
- {
- return new StreamingChatCompletions(updates);
- }
-
/// Initializes a new instance of AudioTranscription.
/// Transcribed text.
/// Language detected in the source audio file.
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
index e46691d42ae88..15fcfc323d50a 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
@@ -4,10 +4,12 @@
#nullable disable
using System;
+using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
+using Azure.Core.Sse;
namespace Azure.AI.OpenAI
{
@@ -239,9 +241,11 @@ public virtual Task> GetCompletionsAsync(
///
/// Service returned a non-success status code.
///
- /// A response that, if the request was successful, includes a instance.
+ /// A response that, if the request was successful, may be asynchronously enumerated for
+ /// instances.
///
- public virtual Response GetCompletionsStreaming(
+ [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")]
+ public virtual StreamingResponse GetCompletionsStreaming(
string deploymentOrModelName,
CompletionsOptions completionsOptions,
CancellationToken cancellationToken = default)
@@ -268,7 +272,12 @@ public virtual Response GetCompletionsStreaming(
context);
message.BufferResponse = false;
Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
- return Response.FromValue(new StreamingCompletions(baseResponse), baseResponse);
+ return new StreamingResponse(
+ baseResponse,
+ SseAsyncEnumerator.EnumerateFromSseStream(
+ baseResponse.ContentStream,
+ Completions.DeserializeCompletions,
+ cancellationToken));
}
catch (Exception e)
{
@@ -278,7 +287,8 @@ public virtual Response GetCompletionsStreaming(
}
///
- public virtual async Task> GetCompletionsStreamingAsync(
+ [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")]
+ public virtual async Task> GetCompletionsStreamingAsync(
string deploymentOrModelName,
CompletionsOptions completionsOptions,
CancellationToken cancellationToken = default)
@@ -306,7 +316,12 @@ public virtual async Task> GetCompletionsStreamin
message.BufferResponse = false;
Response baseResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken)
.ConfigureAwait(false);
- return Response.FromValue(new StreamingCompletions(baseResponse), baseResponse);
+ return new StreamingResponse(
+ baseResponse,
+ SseAsyncEnumerator.EnumerateFromSseStream(
+ baseResponse.ContentStream,
+ Completions.DeserializeCompletions,
+ cancellationToken));
}
catch (Exception e)
{
@@ -420,7 +435,8 @@ public virtual async Task> GetChatCompletionsAsync(
///
/// Service returned a non-success status code.
/// The response returned from the service.
- public virtual Response GetChatCompletionsStreaming(
+ [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")]
+ public virtual StreamingResponse GetChatCompletionsStreaming(
string deploymentOrModelName,
ChatCompletionsOptions chatCompletionsOptions,
CancellationToken cancellationToken = default)
@@ -449,7 +465,12 @@ public virtual Response GetChatCompletionsStreaming(
context);
message.BufferResponse = false;
Response baseResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
- return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
+ return new StreamingResponse(
+ baseResponse,
+ SseAsyncEnumerator.EnumerateFromSseStream(
+ baseResponse.ContentStream,
+ StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates,
+ cancellationToken));
}
catch (Exception e)
{
@@ -459,7 +480,8 @@ public virtual Response GetChatCompletionsStreaming(
}
///
- public virtual async Task> GetChatCompletionsStreamingAsync(
+ [SuppressMessage("Usage", "AZC0015:Unexpected client method return type.")]
+ public virtual async Task> GetChatCompletionsStreamingAsync(
string deploymentOrModelName,
ChatCompletionsOptions chatCompletionsOptions,
CancellationToken cancellationToken = default)
@@ -491,7 +513,12 @@ public virtual async Task> GetChatCompletions
message,
context,
cancellationToken).ConfigureAwait(false);
- return Response.FromValue(new StreamingChatCompletions(baseResponse), baseResponse);
+ return new StreamingResponse(
+ baseResponse,
+ SseAsyncEnumerator.EnumerateFromSseStream(
+ baseResponse.ContentStream,
+ StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates,
+ cancellationToken));
}
catch (Exception e)
{
@@ -709,7 +736,7 @@ public virtual async Task> GetAudioTranscriptionAsy
Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions));
- using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
scope.Start();
audioTranscriptionOptions.InternalNonAzureModelName = deploymentId;
@@ -721,7 +748,8 @@ public virtual async Task> GetAudioTranscriptionAsy
try
{
using HttpMessage message = CreateGetAudioTranscriptionRequest(deploymentId, content, context);
- rawResponse = await _pipeline.ProcessMessageAsync(message, context).ConfigureAwait(false);
+ rawResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken)
+ .ConfigureAwait(false);
}
catch (Exception e)
{
@@ -747,7 +775,7 @@ public virtual Response GetAudioTranscription(string deploym
Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
Argument.AssertNotNull(audioTranscriptionOptions, nameof(audioTranscriptionOptions));
- using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranscription");
scope.Start();
audioTranscriptionOptions.InternalNonAzureModelName = deploymentId;
@@ -759,7 +787,7 @@ public virtual Response GetAudioTranscription(string deploym
try
{
using HttpMessage message = CreateGetAudioTranscriptionRequest(deploymentId, content, context);
- rawResponse = _pipeline.ProcessMessage(message, context);
+ rawResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
}
catch (Exception e)
{
@@ -788,7 +816,7 @@ public virtual async Task> GetAudioTranslationAsync(s
// Custom code: merely linking the deployment ID (== model name) into the request body for non-Azure use
audioTranslationOptions.InternalNonAzureModelName = deploymentId;
- using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation");
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation");
scope.Start();
audioTranslationOptions.InternalNonAzureModelName = deploymentId;
@@ -800,7 +828,8 @@ public virtual async Task> GetAudioTranslationAsync(s
try
{
using HttpMessage message = CreateGetAudioTranslationRequest(deploymentId, content, context);
- rawResponse = await _pipeline.ProcessMessageAsync(message, context).ConfigureAwait(false);
+ rawResponse = await _pipeline.ProcessMessageAsync(message, context, cancellationToken)
+ .ConfigureAwait(false);
}
catch (Exception e)
{
@@ -826,7 +855,7 @@ public virtual Response GetAudioTranslation(string deploymentI
Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
Argument.AssertNotNull(audioTranslationOptions, nameof(audioTranslationOptions));
- using var scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation");
+ using DiagnosticScope scope = ClientDiagnostics.CreateScope("OpenAIClient.GetAudioTranslation");
scope.Start();
audioTranslationOptions.InternalNonAzureModelName = deploymentId;
@@ -838,7 +867,7 @@ public virtual Response GetAudioTranslation(string deploymentI
try
{
using HttpMessage message = CreateGetAudioTranslationRequest(deploymentId, content, context);
- rawResponse = _pipeline.ProcessMessage(message, context);
+ rawResponse = _pipeline.ProcessMessage(message, context, cancellationToken);
}
catch (Exception e)
{
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
deleted file mode 100644
index ff6a184e318e3..0000000000000
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletions.cs
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text.Json;
-using System.Threading;
-using Azure.Core.Sse;
-
-namespace Azure.AI.OpenAI
-{
- ///
- /// Represents a streaming source of instances that allow an
- /// application to incrementally use and display text and other information as it becomes available during a
- /// generation operation.
- ///
- public class StreamingChatCompletions : IDisposable
- {
- private readonly Response _baseResponse;
- private readonly SseReader _baseResponseReader;
- private readonly List _cachedUpdates;
- private bool _disposedValue;
-
- ///
- /// Creates a new instance of based on a REST response stream.
- ///
- ///
- /// The from which server-sent events will be consumed.
- ///
- internal StreamingChatCompletions(Response response)
- {
- _baseResponse = response;
- _baseResponseReader = new SseReader(response.ContentStream);
- _cachedUpdates = new();
- }
-
- ///
- /// Creates a new instance of based on an existing set of
- /// instances. Typically used for tests or mocking.
- ///
- ///
- /// An existing collection of instances to use. Update enumeration
- /// will yield these instances instead of instances derived from a response stream.
- ///
- internal StreamingChatCompletions(IEnumerable updates)
- {
- _cachedUpdates = new(updates);
- }
-
- ///
- /// Returns an asynchronous enumeration of instances as they
- /// become available via the associated network response or other data source.
- ///
- ///
- /// This collection may be iterated over using the "await foreach" statement.
- ///
- ///
- ///
- /// An optional used to end enumeration before it would normally complete.
- ///
- ///
- /// Cancellation will immediately close any underlying network response stream and may consequently limit
- /// incurred token generation and associated cost.
- ///
- ///
- ///
- /// An asynchronous enumeration of instances.
- ///
- ///
- /// Thrown when an underlying data source provided data in an unsupported format, e.g. server-sent events not
- /// prefixed with the expected 'data:' tag.
- ///
- public async IAsyncEnumerable EnumerateChatUpdates(
- [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- if (_cachedUpdates.Any())
- {
- foreach (StreamingChatCompletionsUpdate update in _cachedUpdates)
- {
- yield return update;
- }
- yield break;
- }
-
- // Open clarification points:
- // - Is it acceptable (or even desirable) that we won't pump the content stream until enumeration is requested?
- // - What should happen if the stream is held too long and it times out?
- // - Should enumeration be concurrency-protected, i.e. possible and correct on two threads concurrently?
- while (!cancellationToken.IsCancellationRequested)
- {
- SseLine? sseEvent = await _baseResponseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false);
- if (sseEvent == null)
- {
- _baseResponse?.ContentStream?.Dispose();
- break;
- }
-
- ReadOnlyMemory name = sseEvent.Value.FieldName;
- if (!name.Span.SequenceEqual("data".AsSpan()))
- throw new InvalidDataException();
-
- ReadOnlyMemory value = sseEvent.Value.FieldValue;
- if (value.Span.SequenceEqual("[DONE]".AsSpan()))
- {
- _baseResponse?.ContentStream?.Dispose();
- break;
- }
-
- JsonDocument sseMessageJson = JsonDocument.Parse(sseEvent.Value.FieldValue);
- foreach (StreamingChatCompletionsUpdate streamingChatCompletionsUpdate
- in StreamingChatCompletionsUpdate.DeserializeStreamingChatCompletionsUpdates(sseMessageJson.RootElement))
- {
- _cachedUpdates.Add(streamingChatCompletionsUpdate);
- yield return streamingChatCompletionsUpdate;
- }
- }
- _baseResponse?.ContentStream?.Dispose();
- }
-
- ///
- public void Dispose()
- {
- Dispose(disposing: true);
- GC.SuppressFinalize(this);
- }
-
- protected virtual void Dispose(bool disposing)
- {
- if (!_disposedValue)
- {
- if (disposing)
- {
- _baseResponseReader?.Dispose();
- }
-
- _disposedValue = true;
- }
- }
- }
-}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
index f28ddcf5f94fd..7232ca9845853 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/StreamingChatCompletionsUpdate.cs
@@ -2,9 +2,18 @@
// Licensed under the MIT License.
using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading;
+using Azure.Core.Sse;
namespace Azure.AI.OpenAI
{
+ ///
+ /// Represents an incremental update to a streamed Chat Completions response.
+ ///
public partial class StreamingChatCompletionsUpdate
{
///
@@ -84,8 +93,8 @@ public partial class StreamingChatCompletionsUpdate
///
///
/// As is the case for non-streaming