Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Azure OpenAI] Change representation of embeddings from IReadOnlyList<float> to ReadOnlyMemory<float> #39679

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Breaking Changes

- Changed the representation of embeddings from `IReadOnlyList<float>` to `ReadOnlyMemory<float>`.

### Bugs Fixed

### Other Changes
Expand Down
22 changes: 17 additions & 5 deletions sdk/openai/Azure.AI.OpenAI/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ We guarantee that all client instance methods are thread-safe and independent of

You can familiarize yourself with different APIs using [Samples](https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/openai/Azure.AI.OpenAI/tests/Samples).

### Generate Chatbot Response
### Generate chatbot response

The `GenerateChatbotResponse` method authenticates using a DefaultAzureCredential, then generates text responses to input prompts.

Expand All @@ -120,7 +120,7 @@ string completion = completionsResponse.Value.Choices[0].Text;
Console.WriteLine($"Chatbot: {completion}");
```

### Generate Multiple Chatbot Responses With Subscription Key
### Generate multiple chatbot responses with subscription key

The `GenerateMultipleChatbotResponsesWithSubscriptionKey` method gives an example of generating text responses to input prompts using an Azure subscription key

Expand Down Expand Up @@ -152,7 +152,7 @@ foreach (string prompt in examplePrompts)
}
```

### Summarize Text with Completion
### Summarize text with completion

The `SummarizeText` method generates a summarization of the given input prompt.

Expand Down Expand Up @@ -190,7 +190,7 @@ string completion = completionsResponse.Value.Choices[0].Text;
Console.WriteLine($"Summarization: {completion}");
```

### Stream Chat Messages with non-Azure OpenAI
### Stream chat messages with non-Azure OpenAI

```C# Snippet:StreamChatMessages
string nonAzureOpenAIApiKey = "your-api-key-from-platform.openai.com";
Expand Down Expand Up @@ -221,7 +221,7 @@ await foreach (StreamingChatChoice choice in streamingChatCompletions.GetChoices
}
```

### Use Chat Functions
### Use chat functions

Chat Functions allow a caller of Chat Completions to define capabilities that the model can use to extend its
functionality into external tools and data sources.
Expand Down Expand Up @@ -385,6 +385,18 @@ foreach (ChatMessage contextMessage in message.AzureExtensionsContext.Messages)
}
```

### Generate embeddings

```C# Snippet:GenerateEmbeddings
string deploymentOrModelName = "text-embedding-ada-002";
EmbeddingsOptions embeddingsOptions = new("Your text string goes here");
Response<Embeddings> response = await client.GetEmbeddingsAsync(deploymentOrModelName, embeddingsOptions);

// The response includes the generated embedding.
EmbeddingItem item = response.Value.Data[0];
ReadOnlyMemory<float> embedding = item.Embedding;
``````

### Generate images with DALL-E image generation models

```C# Snippet:GenerateImages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public static partial class AzureOpenAIModelFactory
public static Azure.AI.OpenAI.CompletionsUsage CompletionsUsage(int completionTokens = 0, int promptTokens = 0, int totalTokens = 0) { throw null; }
public static Azure.AI.OpenAI.ContentFilterResult ContentFilterResult(Azure.AI.OpenAI.ContentFilterSeverity severity = default(Azure.AI.OpenAI.ContentFilterSeverity), bool filtered = false) { throw null; }
public static Azure.AI.OpenAI.ContentFilterResults ContentFilterResults(Azure.AI.OpenAI.ContentFilterResult sexual = null, Azure.AI.OpenAI.ContentFilterResult violence = null, Azure.AI.OpenAI.ContentFilterResult hate = null, Azure.AI.OpenAI.ContentFilterResult selfHarm = null, Azure.ResponseError error = null) { throw null; }
public static Azure.AI.OpenAI.EmbeddingItem EmbeddingItem(System.Collections.Generic.IEnumerable<float> embedding = null, int index = 0) { throw null; }
public static Azure.AI.OpenAI.EmbeddingItem EmbeddingItem(System.ReadOnlyMemory<float> embedding = default(System.ReadOnlyMemory<float>), int index = 0) { throw null; }
public static Azure.AI.OpenAI.Embeddings Embeddings(System.Collections.Generic.IEnumerable<Azure.AI.OpenAI.EmbeddingItem> data = null, Azure.AI.OpenAI.EmbeddingsUsage usage = null) { throw null; }
public static Azure.AI.OpenAI.EmbeddingsUsage EmbeddingsUsage(int promptTokens = 0, int totalTokens = 0) { throw null; }
public static Azure.AI.OpenAI.ImageGenerations ImageGenerations(System.DateTimeOffset created = default(System.DateTimeOffset), System.Collections.Generic.IEnumerable<Azure.AI.OpenAI.ImageLocation> data = null) { throw null; }
Expand Down Expand Up @@ -384,7 +384,7 @@ internal ContentFilterResults() { }
public partial class EmbeddingItem
{
internal EmbeddingItem() { }
public System.Collections.Generic.IReadOnlyList<float> Embedding { get { throw null; } }
public System.ReadOnlyMemory<float> Embedding { get { throw null; } }
public int Index { get { throw null; } }
}
public partial class Embeddings
Expand Down
36 changes: 36 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/src/Custom/EmbeddingItem.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using Azure.Core;

namespace Azure.AI.OpenAI
{
/// <summary> Representation of a single embeddings relatedness comparison. </summary>
public partial class EmbeddingItem
{
// CUSTOM CODE NOTE: The change of the Embedding property to ReadOnlyMemory<float>
// results in the constructor below being generated twice and thus a build error.
// We're adding it here verbatim as custom code to work around this issue.

/// <summary> Initializes a new instance of EmbeddingItem. </summary>
/// <param name="embedding">
/// List of embeddings value for the input prompt. These represent a measurement of the
/// vector-based relatedness of the provided input.
/// </param>
/// <param name="index"> Index of the prompt to which the EmbeddingItem corresponds. </param>
internal EmbeddingItem(ReadOnlyMemory<float> embedding, int index)
{
Embedding = embedding;
Index = index;
}

// CUSTOM CODE NOTE: We want to represent embeddings as ReadOnlyMemory<float> instead
// of IReadOnlyList<float>.
public ReadOnlyMemory<float> Embedding { get; }
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 0 additions & 35 deletions sdk/openai/Azure.AI.OpenAI/src/Generated/EmbeddingItem.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ public async Task Embeddings(OpenAIClientServiceTarget serviceTarget)
Assert.That(embeddingsRequest, Is.InstanceOf<EmbeddingsOptions>());
Response<Embeddings> response = await client.GetEmbeddingsAsync(deploymentOrModelName, embeddingsRequest);
Assert.That(response, Is.InstanceOf<Response<Embeddings>>());
Assert.That(response.Value, Is.Not.Null);
Assert.That(response.Value.Data, Is.Not.Null.Or.Empty);

EmbeddingItem firstItem = response.Value.Data[0];
Assert.That(firstItem, Is.Not.Null);
Assert.That(firstItem.Index, Is.EqualTo(0));
Assert.That(firstItem.Embedding, Is.Not.Null.Or.Empty);
}

[RecordedTest]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Azure.AI.OpenAI.Tests.Samples
{
public partial class StreamingChat
public partial class ImagesSamples
{
[Test]
[Ignore("Only verifying that the sample builds")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Azure.AI.OpenAI.Tests.Samples
{
public partial class StreamingChat
public partial class AudioSamples
{
[Test]
[Ignore("Only verifying that the sample builds")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Azure.AI.OpenAI.Tests.Samples
{
public partial class StreamingChat
public partial class AudioSamples
{
[Test]
[Ignore("Only verifying that the sample builds")]
Expand Down
31 changes: 31 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample11_Embeddings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Threading.Tasks;
using Azure.Identity;
using NUnit.Framework;

namespace Azure.AI.OpenAI.Tests.Samples
{
public partial class EmbeddingsSamples
{
[Test]
[Ignore("Only verifying that the sample builds")]
public async Task GenerateEmbeddings()
{
string endpoint = "https://myaccount.openai.azure.com/";
var client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential());

#region Snippet:GenerateEmbeddings
string deploymentOrModelName = "text-embedding-ada-002";
EmbeddingsOptions embeddingsOptions = new("Your text string goes here");
Response<Embeddings> response = await client.GetEmbeddingsAsync(deploymentOrModelName, embeddingsOptions);

// The response includes the generated embedding.
EmbeddingItem item = response.Value.Data[0];
ReadOnlyMemory<float> embedding = item.Embedding;
#endregion
}
}
}