diff --git a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md index fd55414684f89..65ac4ae4477aa 100644 --- a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md +++ b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md @@ -6,6 +6,8 @@ ### Breaking Changes +- Changed the representation of embeddings from `IReadOnlyList` to `ReadOnlyMemory`. + ### Bugs Fixed ### Other Changes diff --git a/sdk/openai/Azure.AI.OpenAI/README.md b/sdk/openai/Azure.AI.OpenAI/README.md index 7c296debd4bee..bf361355609e3 100644 --- a/sdk/openai/Azure.AI.OpenAI/README.md +++ b/sdk/openai/Azure.AI.OpenAI/README.md @@ -103,7 +103,7 @@ We guarantee that all client instance methods are thread-safe and independent of You can familiarize yourself with different APIs using [Samples](https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/openai/Azure.AI.OpenAI/tests/Samples). -### Generate Chatbot Response +### Generate chatbot response The `GenerateChatbotResponse` method authenticates using a DefaultAzureCredential, then generates text responses to input prompts. @@ -120,7 +120,7 @@ string completion = completionsResponse.Value.Choices[0].Text; Console.WriteLine($"Chatbot: {completion}"); ``` -### Generate Multiple Chatbot Responses With Subscription Key +### Generate multiple chatbot responses with subscription key The `GenerateMultipleChatbotResponsesWithSubscriptionKey` method gives an example of generating text responses to input prompts using an Azure subscription key @@ -152,7 +152,7 @@ foreach (string prompt in examplePrompts) } ``` -### Summarize Text with Completion +### Summarize text with completion The `SummarizeText` method generates a summarization of the given input prompt. @@ -190,7 +190,7 @@ string completion = completionsResponse.Value.Choices[0].Text; Console.WriteLine($"Summarization: {completion}"); ``` -### Stream Chat Messages with non-Azure OpenAI +### Stream chat messages with non-Azure OpenAI ```C# Snippet:StreamChatMessages string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com"; @@ -221,7 +221,7 @@ await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoices } ``` -### Use Chat Functions +### Use chat functions Chat Functions allow a caller of Chat Completions to define capabilities that the model can use to extend its functionality into external tools and data sources. @@ -385,6 +385,18 @@ foreach (ChatMessage contextMessage in message.AzureExtensionsContext.Messages) } ``` +### Generate embeddings + +```C# Snippet:GenerateEmbeddings +string deploymentOrModelName = "text-embedding-ada-002"; +EmbeddingsOptions embeddingsOptions = new("Your text string goes here"); +Response response = await client.GetEmbeddingsAsync(deploymentOrModelName, embeddingsOptions); + +// The response includes the generated embedding. +EmbeddingItem item = response.Value.Data[0]; +ReadOnlyMemory embedding = item.Embedding; +`````` + ### Generate images with DALL-E image generation models ```C# Snippet:GenerateImages diff --git a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs index a4b3f10dcf7df..bfddd9f4fe519 100644 --- a/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs +++ b/sdk/openai/Azure.AI.OpenAI/api/Azure.AI.OpenAI.netstandard2.0.cs @@ -198,7 +198,7 @@ public static partial class AzureOpenAIModelFactory public static Azure.AI.OpenAI.CompletionsUsage CompletionsUsage(int completionTokens = 0, int promptTokens = 0, int totalTokens = 0) { throw null; } public static Azure.AI.OpenAI.ContentFilterResult ContentFilterResult(Azure.AI.OpenAI.ContentFilterSeverity severity = default(Azure.AI.OpenAI.ContentFilterSeverity), bool filtered = false) { throw null; } public static Azure.AI.OpenAI.ContentFilterResults ContentFilterResults(Azure.AI.OpenAI.ContentFilterResult sexual = null, Azure.AI.OpenAI.ContentFilterResult violence = null, Azure.AI.OpenAI.ContentFilterResult hate = null, Azure.AI.OpenAI.ContentFilterResult selfHarm = null, Azure.ResponseError error = null) { throw null; } - public static Azure.AI.OpenAI.EmbeddingItem EmbeddingItem(System.Collections.Generic.IEnumerable embedding = null, int index = 0) { throw null; } + public static Azure.AI.OpenAI.EmbeddingItem EmbeddingItem(System.ReadOnlyMemory embedding = default(System.ReadOnlyMemory), int index = 0) { throw null; } public static Azure.AI.OpenAI.Embeddings Embeddings(System.Collections.Generic.IEnumerable data = null, Azure.AI.OpenAI.EmbeddingsUsage usage = null) { throw null; } public static Azure.AI.OpenAI.EmbeddingsUsage EmbeddingsUsage(int promptTokens = 0, int totalTokens = 0) { throw null; } public static Azure.AI.OpenAI.ImageGenerations ImageGenerations(System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable data = null) { throw null; } @@ -384,7 +384,7 @@ internal ContentFilterResults() { } public partial class EmbeddingItem { internal EmbeddingItem() { } - public System.Collections.Generic.IReadOnlyList Embedding { get { throw null; } } + public System.ReadOnlyMemory Embedding { get { throw null; } } public int Index { get { throw null; } } } public partial class Embeddings diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/EmbeddingItem.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/EmbeddingItem.cs new file mode 100644 index 0000000000000..49858f629e2ca --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/EmbeddingItem.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using Azure.Core; + +namespace Azure.AI.OpenAI +{ + /// Representation of a single embeddings relatedness comparison. + public partial class EmbeddingItem + { + // CUSTOM CODE NOTE: The change of the Embedding property to ReadOnlyMemory + // results in the constructor below being generated twice and thus a build error. + // We're adding it here verbatim as custom code to work around this issue. + + /// Initializes a new instance of EmbeddingItem. + /// + /// List of embeddings value for the input prompt. These represent a measurement of the + /// vector-based relatedness of the provided input. + /// + /// Index of the prompt to which the EmbeddingItem corresponds. + internal EmbeddingItem(ReadOnlyMemory embedding, int index) + { + Embedding = embedding; + Index = index; + } + + // CUSTOM CODE NOTE: We want to represent embeddings as ReadOnlyMemory instead + // of IReadOnlyList. + public ReadOnlyMemory Embedding { get; } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIModelFactory.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIModelFactory.cs index 4d48aec9e8275..b73982221f51b 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIModelFactory.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/AzureOpenAIModelFactory.cs @@ -33,11 +33,9 @@ public static Embeddings Embeddings(IEnumerable data = null, Embe /// /// Index of the prompt to which the EmbeddingItem corresponds. /// A new instance for mocking. - public static EmbeddingItem EmbeddingItem(IEnumerable embedding = null, int index = default) + public static EmbeddingItem EmbeddingItem(ReadOnlyMemory embedding = default, int index = default) { - embedding ??= new List(); - - return new EmbeddingItem(embedding?.ToList(), index); + return new EmbeddingItem(embedding, index); } /// Initializes a new instance of EmbeddingsUsage. diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.Serialization.cs index 73fcef3c328fe..d0beb31615985 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.Serialization.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.Serialization.cs @@ -5,7 +5,7 @@ #nullable disable -using System.Collections.Generic; +using System; using System.Text.Json; using Azure; @@ -19,18 +19,24 @@ internal static EmbeddingItem DeserializeEmbeddingItem(JsonElement element) { return null; } - IReadOnlyList embedding = default; + ReadOnlyMemory embedding = default; int index = default; foreach (var property in element.EnumerateObject()) { if (property.NameEquals("embedding"u8)) { - List array = new List(); + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + int index0 = 0; + float[] array = new float[property.Value.GetArrayLength()]; foreach (var item in property.Value.EnumerateArray()) { - array.Add(item.GetSingle()); + array[index0] = item.GetSingle(); + index0++; } - embedding = array; + embedding = new ReadOnlyMemory(array); continue; } if (property.NameEquals("index"u8)) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.cs index b1cd9e879f18f..928746b1982e8 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.cs @@ -6,47 +6,12 @@ #nullable disable using System; -using System.Collections.Generic; -using System.Linq; -using Azure.Core; namespace Azure.AI.OpenAI { /// Representation of a single embeddings relatedness comparison. public partial class EmbeddingItem { - /// Initializes a new instance of EmbeddingItem. - /// - /// List of embeddings value for the input prompt. These represent a measurement of the - /// vector-based relatedness of the provided input. - /// - /// Index of the prompt to which the EmbeddingItem corresponds. - /// is null. - internal EmbeddingItem(IEnumerable embedding, int index) - { - Argument.AssertNotNull(embedding, nameof(embedding)); - - Embedding = embedding.ToList(); - Index = index; - } - - /// Initializes a new instance of EmbeddingItem. - /// - /// List of embeddings value for the input prompt. These represent a measurement of the - /// vector-based relatedness of the provided input. - /// - /// Index of the prompt to which the EmbeddingItem corresponds. - internal EmbeddingItem(IReadOnlyList embedding, int index) - { - Embedding = embedding; - Index = index; - } - - /// - /// List of embeddings value for the input prompt. These represent a measurement of the - /// vector-based relatedness of the provided input. - /// - public IReadOnlyList Embedding { get; } /// Index of the prompt to which the EmbeddingItem corresponds. public int Index { get; } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs index f7eaaf696bc35..2b9034d25e912 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs @@ -84,6 +84,13 @@ public async Task Embeddings(OpenAIClientServiceTarget serviceTarget) Assert.That(embeddingsRequest, Is.InstanceOf()); Response response = await client.GetEmbeddingsAsync(deploymentOrModelName, embeddingsRequest); Assert.That(response, Is.InstanceOf>()); + Assert.That(response.Value, Is.Not.Null); + Assert.That(response.Value.Data, Is.Not.Null.Or.Empty); + + EmbeddingItem firstItem = response.Value.Data[0]; + Assert.That(firstItem, Is.Not.Null); + Assert.That(firstItem.Index, Is.EqualTo(0)); + Assert.That(firstItem.Embedding, Is.Not.Null.Or.Empty); } [RecordedTest] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample06_ImageGenerations.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample06_ImageGenerations.cs index d6b9cd5ecf6af..f8416d63665e3 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample06_ImageGenerations.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample06_ImageGenerations.cs @@ -9,7 +9,7 @@ namespace Azure.AI.OpenAI.Tests.Samples { - public partial class StreamingChat + public partial class ImagesSamples { [Test] [Ignore("Only verifying that the sample builds")] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample09_TranscribeAudio.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample09_TranscribeAudio.cs index fb77d2f21993e..4ee6c7e8a0552 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample09_TranscribeAudio.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample09_TranscribeAudio.cs @@ -9,7 +9,7 @@ namespace Azure.AI.OpenAI.Tests.Samples { - public partial class StreamingChat + public partial class AudioSamples { [Test] [Ignore("Only verifying that the sample builds")] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample10_TranslateAudio.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample10_TranslateAudio.cs index 8b3bfabe95ec1..21a1885949885 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample10_TranslateAudio.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample10_TranslateAudio.cs @@ -9,7 +9,7 @@ namespace Azure.AI.OpenAI.Tests.Samples { - public partial class StreamingChat + public partial class AudioSamples { [Test] [Ignore("Only verifying that the sample builds")] diff --git a/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample11_Embeddings.cs b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample11_Embeddings.cs new file mode 100644 index 0000000000000..92278a8272e33 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample11_Embeddings.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using Azure.Identity; +using NUnit.Framework; + +namespace Azure.AI.OpenAI.Tests.Samples +{ + public partial class EmbeddingsSamples + { + [Test] + [Ignore("Only verifying that the sample builds")] + public async Task GenerateEmbeddings() + { + string endpoint = "https://myaccount.openai.azure.com/"; + var client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential()); + + #region Snippet:GenerateEmbeddings + string deploymentOrModelName = "text-embedding-ada-002"; + EmbeddingsOptions embeddingsOptions = new("Your text string goes here"); + Response response = await client.GetEmbeddingsAsync(deploymentOrModelName, embeddingsOptions); + + // The response includes the generated embedding. + EmbeddingItem item = response.Value.Data[0]; + ReadOnlyMemory embedding = item.Embedding; + #endregion + } + } +}