Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make GenerateAsync extension just return the embedding #5543

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand All @@ -18,7 +19,7 @@ public static class EmbeddingGeneratorExtensions
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
public static Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync<TValue, TEmbedding>(
public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
this IEmbeddingGenerator<TValue, TEmbedding> generator,
TValue value,
EmbeddingGenerationOptions? options = null,
Expand All @@ -28,6 +29,12 @@ public static Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync<TValue, TEmbed
_ = Throw.IfNull(generator);
_ = Throw.IfNull(value);

return generator.GenerateAsync([value], options, cancellationToken);
var embeddings = await generator.GenerateAsync([value], options, cancellationToken).ConfigureAwait(false);
if (embeddings.Count != 1)
{
throw new InvalidOperationException("Expected exactly one embedding to be generated.");
}

return embeddings[0];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result])
};

Assert.Same(result, (await service.GenerateAsync("hello"))[0]);
Assert.Same(result, await service.GenerateAsync("hello"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully()
{
SkipIfNotEnabled();

var embeddings = await _embeddingGenerator.GenerateAsync("Using AI with .NET");
var embeddings = await _embeddingGenerator.GenerateAsync(["Using AI with .NET"]);

Assert.NotNull(embeddings.Usage);
Assert.NotNull(embeddings.Usage.InputTokenCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,15 @@ public async Task CachesSuccessResultsAsync()

// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateAsync("abc");
Assert.Single(result1);
AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GenerateAsync("abc");

// Assert
Assert.Single(result2);
Assert.Equal(1, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2);

// Act/Assert 2: Cache misses do not return cached results
await outer.GenerateAsync(["def"]);
Expand Down Expand Up @@ -144,13 +142,13 @@ public async Task AllowsConcurrentCallsAsync()
Assert.False(result1.IsCompleted);
Assert.False(result2.IsCompleted);
completionTcs.SetResult(true);
AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]);
AssertEmbeddingsEqual(_expectedEmbedding, (await result2)[0]);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
AssertEmbeddingsEqual(_expectedEmbedding, await result2);

// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}

[Fact]
Expand Down Expand Up @@ -218,9 +216,8 @@ public async Task DoesNotCacheCanceledResultsAsync()

// Act/Assert: Second call can succeed
var result2 = await outer.GenerateAsync("abc");
Assert.Single(result2);
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}

[Fact]
Expand Down Expand Up @@ -254,11 +251,9 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
});

// Assert: Same result
Assert.Single(result1);
Assert.Single(result2);
Assert.Equal(1, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}

[Fact]
Expand Down Expand Up @@ -292,11 +287,9 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
});

// Assert: Different results
Assert.Single(result1);
Assert.Single(result2);
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}

[Fact]
Expand Down
Loading