Skip to content

Commit

Permalink
Make GenerateAsync extension just return the embedding (#5543)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Oct 22, 2024
1 parent 1622413 commit e1eb9bd
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand All @@ -18,7 +19,7 @@ public static class EmbeddingGeneratorExtensions
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
public static Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync<TValue, TEmbedding>(
public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(
this IEmbeddingGenerator<TValue, TEmbedding> generator,
TValue value,
EmbeddingGenerationOptions? options = null,
Expand All @@ -28,6 +29,12 @@ public static Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync<TValue, TEmbed
_ = Throw.IfNull(generator);
_ = Throw.IfNull(value);

return generator.GenerateAsync([value], options, cancellationToken);
var embeddings = await generator.GenerateAsync([value], options, cancellationToken).ConfigureAwait(false);
if (embeddings.Count != 1)
{
throw new InvalidOperationException("Expected exactly one embedding to be generated.");
}

return embeddings[0];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result])
};

Assert.Same(result, (await service.GenerateAsync("hello"))[0]);
Assert.Same(result, await service.GenerateAsync("hello"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully()
{
SkipIfNotEnabled();

var embeddings = await _embeddingGenerator.GenerateAsync("Using AI with .NET");
var embeddings = await _embeddingGenerator.GenerateAsync(["Using AI with .NET"]);

Assert.NotNull(embeddings.Usage);
Assert.NotNull(embeddings.Usage.InputTokenCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,15 @@ public async Task CachesSuccessResultsAsync()

// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateAsync("abc");
Assert.Single(result1);
AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GenerateAsync("abc");

// Assert
Assert.Single(result2);
Assert.Equal(1, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2);

// Act/Assert 2: Cache misses do not return cached results
await outer.GenerateAsync(["def"]);
Expand Down Expand Up @@ -144,13 +142,13 @@ public async Task AllowsConcurrentCallsAsync()
Assert.False(result1.IsCompleted);
Assert.False(result2.IsCompleted);
completionTcs.SetResult(true);
AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]);
AssertEmbeddingsEqual(_expectedEmbedding, (await result2)[0]);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
AssertEmbeddingsEqual(_expectedEmbedding, await result2);

// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}

[Fact]
Expand Down Expand Up @@ -218,9 +216,8 @@ public async Task DoesNotCacheCanceledResultsAsync()

// Act/Assert: Second call can succeed
var result2 = await outer.GenerateAsync("abc");
Assert.Single(result2);
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}

[Fact]
Expand Down Expand Up @@ -254,11 +251,9 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
});

// Assert: Same result
Assert.Single(result1);
Assert.Single(result2);
Assert.Equal(1, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}

[Fact]
Expand Down Expand Up @@ -292,11 +287,9 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
});

// Assert: Different results
Assert.Single(result1);
Assert.Single(result2);
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}

[Fact]
Expand Down

0 comments on commit e1eb9bd

Please sign in to comment.