Skip to content

Commit

Permalink
.Net: Amazon Connector - InnerContent Metadata Support (#10086)
Browse files Browse the repository at this point in the history
### Motivation and Context

Amazon connector was not exposing the breaking glass ability to expose
the underlying component used by the Bedrock SDK for retrieving extra
metadata associated with ChatCOmpletion and TextGeneration requests.
This was creating a problem for implementers wanting to capture extra
information like: stop reason, usage, request id and more...

### Description

- Add missing InnerContent support for metadata (including Usage when
available).
- Fixes #9989 
- Fix warnings (Amazon projects) in the new Analyzer version.

---------

Co-authored-by: tanaka_733 <[email protected]>
  • Loading branch information
RogerBarreto and tanaka-takayoshi authored Jan 7, 2025
1 parent 9583f5f commit 4a70658
Show file tree
Hide file tree
Showing 17 changed files with 75 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<TargetFramework>net8.0</TargetFramework>
<Nullable>enable</Nullable>
<RootNamespace>AmazonBedrockAIModels</RootNamespace>
<UserSecretsId>5ee045b0-aea3-4f08-8d31-32d1a6f8fed0</UserSecretsId>
</PropertyGroup>
<PropertyGroup>
<NoWarn>$(NoWarn);SKEXP0001;SKEXP0070</NoWarn>
Expand Down
11 changes: 9 additions & 2 deletions dotnet/samples/Demos/AmazonBedrockModels/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading.Tasks;
using Amazon.BedrockRuntime.Model;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.TextGeneration;
Expand All @@ -28,12 +30,14 @@
case 4:
await PerformStreamTextGeneration().ConfigureAwait(false);
break;
default:
throw new InvalidOperationException("Invalid choice");
}

async Task PerformChatCompletion()
{
string userInput;
ChatHistory chatHistory = new();
ChatHistory chatHistory = [];

// Get available chat completion models
var availableChatModels = bedrockModels.Values
Expand All @@ -60,6 +64,8 @@ async Task PerformChatCompletion()
{
output += message.Content;
Console.WriteLine($"Chat Completion Answer: {message.Content}");
var innerContent = message.InnerContent as ConverseResponse;
Console.WriteLine($"Usage Metadata: {JsonSerializer.Serialize(innerContent?.Usage)}");
Console.WriteLine();
}

Expand Down Expand Up @@ -91,6 +97,7 @@ async Task PerformTextGeneration()
if (firstTextContent != null)
{
Console.WriteLine("Text Generation Answer: " + firstTextContent.Text);
Console.WriteLine($"Metadata: {JsonSerializer.Serialize(firstTextContent.InnerContent)}");
}
else
{
Expand All @@ -106,7 +113,7 @@ async Task PerformTextGeneration()
async Task PerformStreamChatCompletion()
{
string userInput;
ChatHistory streamChatHistory = new();
ChatHistory streamChatHistory = [];

// Get available streaming chat completion models
var availableStreamingChatModels = bedrockModels.Values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ public async Task GetChatMessageContentsAsyncShouldReturnChatMessageContentsAsyn
Assert.Equal(AuthorRole.Assistant, result[0].Role);
Assert.Single(result[0].Items);
Assert.Equal("Hello, world!", result[0].Items[0].ToString());
Assert.NotNull(result[0].InnerContent);
}

/// <summary>
Expand Down Expand Up @@ -160,6 +161,7 @@ public async Task GetStreamingChatMessageContentsAsyncShouldReturnStreamedChatMe
Assert.NotNull(item);
Assert.NotNull(item.Content);
Assert.NotNull(item.Role);
Assert.NotNull(item.InnerContent);
output.Add(item);
}
Assert.True(output.Count > 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ public async Task GetTextContentsAsyncShouldReturnTextContentsAsync()
// Assert
Assert.Single(result);
Assert.Equal("This is a mock output.", result[0].Text);
Assert.NotNull(result[0].InnerContent);
}

/// <summary>
Expand Down Expand Up @@ -210,6 +211,7 @@ public async Task GetStreamingTextContentsAsyncShouldReturnStreamedTextContentsA
iterations += 1;
Assert.NotNull(item);
Assert.NotNull(item.Text);
Assert.NotNull(item.InnerContent);
result.Add(item);
}
Assert.True(iterations > 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ private ChatMessageContent[] ConvertToMessageContent(ConverseResponse response)
new ChatMessageContent
{
Role = BedrockClientUtilities.MapConversationRoleToAuthorRole(message.Role.Value),
Items = CreateChatMessageContentItemCollection(message.Content)
Items = CreateChatMessageContentItemCollection(message.Content),
InnerContent = response
}
];
}
Expand Down Expand Up @@ -192,11 +193,11 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> StreamChatMessageAs
List<StreamingChatMessageContent>? streamedContents = activity is not null ? [] : null;
foreach (var chunk in response.Stream.AsEnumerable())
{
if (chunk is ContentBlockDeltaEvent)
if (chunk is ContentBlockDeltaEvent deltaEvent)
{
// Convert output to semantic kernel's StreamingChatMessageContent
var c = (chunk as ContentBlockDeltaEvent)?.Delta.Text;
var content = new StreamingChatMessageContent(AuthorRole.Assistant, c);
var c = deltaEvent?.Delta.Text;
var content = new StreamingChatMessageContent(AuthorRole.Assistant, c, deltaEvent);
streamedContents?.Add(content);
yield return content;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,11 @@ internal async IAsyncEnumerable<StreamingTextContent> StreamTextAsync(
{
continue;
}
IEnumerable<string> texts = this._ioTextService.GetTextStreamOutput(chunk);
foreach (var text in texts)

foreach (var streamingContent in this._ioTextService.GetTextStreamOutput(chunk))
{
var content = new StreamingTextContent(text);
streamedContents?.Add(content);
yield return new StreamingTextContent(text);
streamedContents?.Add(streamingContent);
yield return streamingContent;
}
}
activity?.SetStatus(BedrockClientUtilities.ConvertHttpStatusCodeToActivityStatusCode(streamingResponse.HttpStatusCode));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ internal interface IBedrockTextGenerationService
/// </summary>
/// <param name="chunk">The payloadPart bytes provided from the streaming response.</param>
/// <returns><see cref="IEnumerable{String}"/> output strings.</returns>
internal IEnumerable<string> GetTextStreamOutput(JsonNode chunk);
internal IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
{
using var reader = new StreamReader(response.Body);
var responseBody = JsonSerializer.Deserialize<AI21JambaResponse.AI21TextResponse>(reader.ReadToEnd());
List<TextContent> textContents = [];

if (responseBody?.Choices is not { Count: > 0 })
{
return textContents;
return [];
}
textContents.AddRange(responseBody.Choices.Select(choice => new TextContent(choice.Message?.Content)));
return textContents;

return responseBody.Choices
.Select(choice => new TextContent(choice.Message?.Content, innerContent: responseBody))
.ToList();
}

/// <inheritdoc/>
Expand Down Expand Up @@ -108,12 +110,12 @@ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistor
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var choiceDeltaContent = chunk["choices"]?[0]?["delta"]?["content"];
if (choiceDeltaContent is not null)
{
yield return choiceDeltaContent.ToString();
yield return new StreamingTextContent(choiceDeltaContent.ToString(), innerContent: chunk);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
{
using var reader = new StreamReader(response.Body);
var responseBody = JsonSerializer.Deserialize<AI21JurassicResponse>(reader.ReadToEnd());
List<TextContent> textContents = [];

if (responseBody?.Completions is not { Count: > 0 })
{
return textContents;
return [];
}
textContents.AddRange(responseBody.Completions.Select(completion => new TextContent(completion.Data?.Text)));
return textContents;

return responseBody.Completions
.Select(completion => new TextContent(completion.Data?.Text, innerContent: responseBody))
.ToList();
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
throw new NotSupportedException("Streaming not supported by this model.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
{
using var reader = new StreamReader(response.Body);
var responseBody = JsonSerializer.Deserialize<TitanTextResponse>(reader.ReadToEnd());
List<TextContent> textContents = [];
if (responseBody?.Results is not { Count: > 0 })
{
return textContents;
return [];
}

string? outputText = responseBody.Results[0].OutputText;
textContents.Add(new TextContent(outputText));
return textContents;
return [new TextContent(outputText, innerContent: responseBody)];
}

/// <inheritdoc/>
Expand Down Expand Up @@ -85,12 +84,12 @@ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistor
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var text = chunk["outputText"]?.ToString();
if (!string.IsNullOrEmpty(text))
{
yield return text!;
yield return new StreamingTextContent(text, innerContent: chunk)!;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
List<TextContent> textContents = [];
if (!string.IsNullOrEmpty(responseBody?.Completion))
{
textContents.Add(new TextContent(responseBody!.Completion));
textContents.Add(new TextContent(responseBody!.Completion, innerContent: responseBody));
}

return textContents;
Expand Down Expand Up @@ -120,12 +120,12 @@ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistor
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var text = chunk["completion"]?.ToString();
if (!string.IsNullOrEmpty(text))
{
yield return text!;
yield return new StreamingTextContent(text, innerContent: chunk)!;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
List<TextContent> textContents = [];
if (!string.IsNullOrEmpty(responseBody?.Text))
{
textContents.Add(new TextContent(responseBody!.Text));
textContents.Add(new TextContent(responseBody!.Text, innerContent: responseBody));
}

return textContents;
Expand Down Expand Up @@ -140,12 +140,12 @@ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistor
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var text = chunk["text"]?.ToString();
if (!string.IsNullOrEmpty(text))
{
yield return text!;
yield return new StreamingTextContent(text, innerContent: chunk)!;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
{
using var reader = new StreamReader(response.Body);
var responseBody = JsonSerializer.Deserialize<CommandResponse>(reader.ReadToEnd());
List<TextContent> textContents = [];

if (responseBody?.Generations is not { Count: > 0 })
{
return textContents;
return [];
}
textContents.AddRange(from generation in responseBody.Generations
where !string.IsNullOrEmpty(generation.Text)
select new TextContent(generation.Text));
return textContents;

return responseBody.Generations
.Where(g => !string.IsNullOrEmpty(g.Text))
.Select(g => new TextContent(g.Text, innerContent: responseBody))
.ToList();
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var generations = chunk["generations"]?.AsArray();
if (generations != null)
Expand All @@ -63,7 +64,7 @@ public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
var text = generation?["text"]?.ToString();
if (!string.IsNullOrEmpty(text))
{
yield return text!;
yield return new StreamingTextContent(text, innerContent: chunk)!;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
List<TextContent> textContents = [];
if (!string.IsNullOrEmpty(responseBody?.Generation))
{
textContents.Add(new TextContent(responseBody!.Generation));
textContents.Add(new TextContent(responseBody!.Generation, innerContent: responseBody));
}

return textContents;
Expand Down Expand Up @@ -75,12 +75,12 @@ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistor
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var generation = chunk["generation"]?.ToString();
if (!string.IsNullOrEmpty(generation))
{
yield return generation!;
yield return new StreamingTextContent(generation, innerContent: chunk)!;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ public IReadOnlyList<TextContent> GetInvokeResponseBody(InvokeModelResponse resp
{
using var reader = new StreamReader(response.Body);
var responseBody = JsonSerializer.Deserialize<MistralResponse>(reader.ReadToEnd());
List<TextContent> textContents = [];
if (responseBody?.Outputs is not { Count: > 0 })
{
return textContents;
return [];
}
textContents.AddRange(responseBody.Outputs.Select(output => new TextContent(output.Text)));
return textContents;

return responseBody.Outputs
.Select(output => new TextContent(output.Text, innerContent: responseBody))
.ToList();
}

/// <inheritdoc/>
Expand Down Expand Up @@ -82,7 +83,7 @@ public ConverseRequest GetConverseRequest(string modelId, ChatHistory chatHistor
}

/// <inheritdoc/>
public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
public IEnumerable<StreamingTextContent> GetTextStreamOutput(JsonNode chunk)
{
var outputs = chunk["outputs"]?.AsArray();
if (outputs != null)
Expand All @@ -92,7 +93,7 @@ public IEnumerable<string> GetTextStreamOutput(JsonNode chunk)
var text = output?["text"]?.ToString();
if (!string.IsNullOrEmpty(text))
{
yield return text!;
yield return new StreamingTextContent(text, innerContent: chunk)!;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public async Task ChatCompletionReturnsValidResponseAsync(string modelId)
foreach (var message in response)
{
output += message.Content;
Assert.NotNull(message.InnerContent);
}
chatHistory.AddAssistantMessage(output);

Expand Down Expand Up @@ -89,6 +90,7 @@ public async Task ChatStreamingReturnsValidResponseAsync(string modelId)
await foreach (var message in response)
{
output += message.Content;
Assert.NotNull(message.InnerContent);
}
chatHistory.AddAssistantMessage(output);

Expand Down
Loading

0 comments on commit 4a70658

Please sign in to comment.