From c78d5a3824704f4a90cf573542e7151953da6896 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 18 Oct 2024 21:13:10 -0400 Subject: [PATCH] Make GenerateAsync extension just return the embedding --- .../EmbeddingGeneratorExtensions.cs | 11 ++++++-- .../EmbeddingGeneratorExtensionsTests.cs | 2 +- .../EmbeddingGeneratorIntegrationTests.cs | 2 +- ...istributedCachingEmbeddingGeneratorTest.cs | 27 +++++++------------ 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index fa2a1df4fbe..944ce0995f8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -18,7 +19,7 @@ public static class EmbeddingGeneratorExtensions /// The embedding generation options to configure the request. /// The to monitor for cancellation requests. The default is . /// The generated embedding for the specified . - public static Task> GenerateAsync( + public static async Task GenerateAsync( this IEmbeddingGenerator generator, TValue value, EmbeddingGenerationOptions? options = null, @@ -28,6 +29,12 @@ public static Task> GenerateAsync>>([result]) }; - Assert.Same(result, (await service.GenerateAsync("hello"))[0]); + Assert.Same(result, await service.GenerateAsync("hello")); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs index 29502f926c6..1929869c487 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs @@ -44,7 +44,7 @@ public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully() { SkipIfNotEnabled(); - var embeddings = await _embeddingGenerator.GenerateAsync("Using AI with .NET"); + var embeddings = await _embeddingGenerator.GenerateAsync(["Using AI with .NET"]); Assert.NotNull(embeddings.Usage); Assert.NotNull(embeddings.Usage.InputTokenCount); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs index 2b4370222c6..9a5086a146d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs @@ -44,17 +44,15 @@ public async Task CachesSuccessResultsAsync() // Make the initial request and do a quick sanity check var result1 = await outer.GenerateAsync("abc"); - Assert.Single(result1); - AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result1); Assert.Equal(1, innerCallCount); // Act var result2 = await outer.GenerateAsync("abc"); // Assert - Assert.Single(result2); Assert.Equal(1, innerCallCount); - AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result2); // Act/Assert 2: Cache misses do not return cached results await outer.GenerateAsync(["def"]); @@ -144,13 +142,13 @@ public async Task AllowsConcurrentCallsAsync() Assert.False(result1.IsCompleted); Assert.False(result2.IsCompleted); completionTcs.SetResult(true); - AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]); - AssertEmbeddingsEqual(_expectedEmbedding, (await result2)[0]); + AssertEmbeddingsEqual(_expectedEmbedding, await result1); + AssertEmbeddingsEqual(_expectedEmbedding, await result2); // Act 2: Subsequent calls after completion are resolved from the cache var result3 = await outer.GenerateAsync("abc"); Assert.Equal(2, innerCallCount); - AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]); + AssertEmbeddingsEqual(_expectedEmbedding, await result1); } [Fact] @@ -218,9 +216,8 @@ public async Task DoesNotCacheCanceledResultsAsync() // Act/Assert: Second call can succeed var result2 = await outer.GenerateAsync("abc"); - Assert.Single(result2); Assert.Equal(2, innerCallCount); - AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result2); } [Fact] @@ -254,11 +251,9 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync() }); // Assert: Same result - Assert.Single(result1); - Assert.Single(result2); Assert.Equal(1, innerCallCount); - AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); - AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result1); + AssertEmbeddingsEqual(_expectedEmbedding, result2); } [Fact] @@ -292,11 +287,9 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync() }); // Assert: Different results - Assert.Single(result1); - Assert.Single(result2); Assert.Equal(2, innerCallCount); - AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); - AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result1); + AssertEmbeddingsEqual(_expectedEmbedding, result2); } [Fact]