Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a few issues in IChatClient implementations #5549

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;

namespace Microsoft.Extensions.AI;

/// <summary>
/// Represents text content in a chat.
/// </summary>
public sealed class TextContent : AIContent
{
private string? _text;

/// <summary>
/// Initializes a new instance of the <see cref="TextContent"/> class.
/// </summary>
/// <param name="text">The text content.</param>
public TextContent(string? text)
{
Text = text;
_text = text;
}

/// <summary>
/// Gets or sets the text content.
/// </summary>
public string? Text { get; set; }
[AllowNull]
public string Text
{
get => _text ?? string.Empty;
set => _text = value;
}

/// <inheritdoc/>
public override string ToString() => Text ?? string.Empty;
public override string ToString() => Text;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
Expand Down Expand Up @@ -410,13 +409,13 @@ private sealed class AzureAIChatToolJson
private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerable<ChatMessage> inputs)
{
// Maps all of the M.E.AI types to the corresponding AzureAI types.
// Unrecognized content is ignored.
// Unrecognized or non-processable content is ignored.

foreach (ChatMessage input in inputs)
{
if (input.Role == ChatRole.System)
{
yield return new ChatRequestSystemMessage(input.Text);
yield return new ChatRequestSystemMessage(input.Text ?? string.Empty);
}
else if (input.Role == ChatRole.Tool)
{
Expand Down Expand Up @@ -444,52 +443,64 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
}
else if (input.Role == ChatRole.User)
{
yield return new ChatRequestUserMessage(input.Contents.Select(static (AIContent item) => item switch
{
TextContent textContent => new ChatMessageTextContentItem(textContent.Text),
ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType) :
imageContent.Uri is string uri ? new ChatMessageImageContentItem(new Uri(uri)) :
(ChatMessageContentItem?)null,
_ => null,
}).Where(c => c is not null));
yield return new ChatRequestUserMessage(GetContentParts(input.Contents));
}
else if (input.Role == ChatRole.Assistant)
{
Dictionary<string, ChatCompletionsToolCall>? toolCalls = null;
// TODO: ChatRequestAssistantMessage only enables text content currently.
// Update it with other content types when it supports that.
ChatRequestAssistantMessage message = new()
{
Content = input.Text
};

foreach (var content in input.Contents)
{
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
if (content is FunctionCallContent { CallId: not null } callRequest)
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
string jsonArguments = JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)));
(toolCalls ??= []).Add(
message.ToolCalls.Add(new ChatCompletionsFunctionToolCall(
callRequest.CallId,
new ChatCompletionsFunctionToolCall(
callRequest.CallId,
callRequest.Name,
jsonArguments));
callRequest.Name,
JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)))));
}
}

ChatRequestAssistantMessage message = new();
if (toolCalls is not null)
{
foreach (var entry in toolCalls)
{
message.ToolCalls.Add(entry.Value);
}
}
else
{
message.Content = input.Text;
}

yield return message;
}
}
}

/// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentItem"/>.</summary>
private static List<ChatMessageContentItem> GetContentParts(IList<AIContent> contents)
{
List<ChatMessageContentItem> parts = [];
foreach (var content in contents)
{
switch (content)
{
case TextContent textContent:
(parts ??= []).Add(new ChatMessageTextContentItem(textContent.Text));
break;

case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
(parts ??= []).Add(new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType));
break;

case ImageContent imageContent when imageContent.Uri is string uri:
(parts ??= []).Add(new ChatMessageImageContentItem(new Uri(uri)));
break;
}
}

if (parts.Count == 0)
{
parts.Add(new ChatMessageTextContentItem(string.Empty));
}

return parts;
}

private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
Expand Down
66 changes: 42 additions & 24 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
Expand Down Expand Up @@ -569,13 +568,16 @@ private sealed class OpenAIChatToolJson
private IEnumerable<OpenAI.Chat.ChatMessage> ToOpenAIChatMessages(IEnumerable<ChatMessage> inputs)
{
// Maps all of the M.E.AI types to the corresponding OpenAI types.
// Unrecognized content is ignored.
// Unrecognized or non-processable content is ignored.

foreach (ChatMessage input in inputs)
{
if (input.Role == ChatRole.System)
if (input.Role == ChatRole.System || input.Role == ChatRole.User)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
yield return new SystemChatMessage(input.Text) { ParticipantName = input.AuthorName };
var parts = GetContentParts(input.Contents);
yield return input.Role == ChatRole.System ?
new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
new UserChatMessage(parts) { ParticipantName = input.AuthorName };
}
else if (input.Role == ChatRole.Tool)
{
Expand All @@ -601,39 +603,25 @@ private sealed class OpenAIChatToolJson
}
}
}
else if (input.Role == ChatRole.User)
{
yield return new UserChatMessage(input.Contents.Select(static (AIContent item) => item switch
{
TextContent textContent => ChatMessageContentPart.CreateTextPart(textContent.Text),
ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType) :
imageContent.Uri is string uri ? ChatMessageContentPart.CreateImagePart(new Uri(uri)) :
null,
_ => null,
}).Where(c => c is not null))
{ ParticipantName = input.AuthorName };
}
else if (input.Role == ChatRole.Assistant)
{
Dictionary<string, ChatToolCall>? toolCalls = null;
AssistantChatMessage message = new(GetContentParts(input.Contents))
{
ParticipantName = input.AuthorName
};

foreach (var content in input.Contents)
{
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
if (content is FunctionCallContent { CallId: not null } callRequest)
{
(toolCalls ??= []).Add(
callRequest.CallId,
message.ToolCalls.Add(
ChatToolCall.CreateFunctionToolCall(
callRequest.CallId,
callRequest.Name,
BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions)));
}
}

AssistantChatMessage message = toolCalls is not null ?
new(toolCalls.Values) { ParticipantName = input.AuthorName } :
new(input.Text) { ParticipantName = input.AuthorName };

if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true)
{
message.Refusal = refusal;
Expand All @@ -644,6 +632,36 @@ private sealed class OpenAIChatToolJson
}
}

/// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentPart"/>.</summary>
private static List<ChatMessageContentPart> GetContentParts(IList<AIContent> contents)
{
List<ChatMessageContentPart> parts = [];
foreach (var content in contents)
{
switch (content)
{
case TextContent textContent:
(parts ??= []).Add(ChatMessageContentPart.CreateTextPart(textContent.Text));
break;

case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
(parts ??= []).Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType));
break;

case ImageContent imageContent when imageContent.Uri is string uri:
(parts ??= []).Add(ChatMessageContentPart.CreateImagePart(new Uri(uri)));
break;
}
}

if (parts.Count == 0)
{
parts.Add(ChatMessageContentPart.CreateTextPart(string.Empty));
}

return parts;
}

private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public void Constructor_String_PropsDefault(string? text)
TextContent c = new(text);
Assert.Null(c.RawRepresentation);
Assert.Null(c.AdditionalProperties);
Assert.Equal(text, c.Text);
Assert.Equal(text ?? string.Empty, c.Text);
}

[Fact]
Expand All @@ -34,13 +34,17 @@ public void Constructor_PropsRoundtrip()
c.AdditionalProperties = props;
Assert.Same(props, c.AdditionalProperties);

Assert.Null(c.Text);
Assert.Equal(string.Empty, c.Text);
c.Text = "text";
Assert.Equal("text", c.Text);
Assert.Equal("text", c.ToString());

c.Text = null;
Assert.Null(c.Text);
Assert.Equal(string.Empty, c.Text);
Assert.Equal(string.Empty, c.ToString());

c.Text = string.Empty;
Assert.Equal(string.Empty, c.Text);
Assert.Equal(string.Empty, c.ToString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,89 @@ public async Task MultipleMessages_NonStreaming()
Assert.Equal(57, response.Usage.TotalTokenCount);
}

[Fact]
public async Task NullAssistantText_ContentSkipped_NonStreaming()
{
const string Input = """
{
"messages": [
{
"role": "assistant"
},
{
"content": [
{
"text": "hello!",
"type": "text"
}
],
"role": "user"
}
],
"model": "gpt-4o-mini"
}
""";

const string Output = """
{
"id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P",
"object": "chat.completion",
"created": 1727894187,
"model": "gpt-4o-mini-2024-07-18",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello.",
"refusal": null
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 42,
"completion_tokens": 15,
"total_tokens": 57,
"prompt_tokens_details": {
"cached_tokens": 0
},
"completion_tokens_details": {
"reasoning_tokens": 0
}
},
"system_fingerprint": "fp_f85bea6784"
}
""";

using VerbatimHttpHandler handler = new(Input, Output);
using HttpClient httpClient = new(handler);
using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini");

List<ChatMessage> messages =
[
new(ChatRole.Assistant, (string?)null),
new(ChatRole.User, "hello!"),
];

var response = await client.CompleteAsync(messages);
Assert.NotNull(response);

Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId);
Assert.Equal("Hello.", response.Message.Text);
Assert.Single(response.Message.Contents);
Assert.Equal(ChatRole.Assistant, response.Message.Role);
Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId);
Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt);
Assert.Equal(ChatFinishReason.Stop, response.FinishReason);

Assert.NotNull(response.Usage);
Assert.Equal(42, response.Usage.InputTokenCount);
Assert.Equal(15, response.Usage.OutputTokenCount);
Assert.Equal(57, response.Usage.TotalTokenCount);
}

[Fact]
public async Task FunctionCallContent_NonStreaming()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ private int CountTokens(ChatMessage message)
int sum = 0;
foreach (AIContent content in message.Contents)
{
if ((content as TextContent)?.Text is string text)
if (content is TextContent text)
{
sum += _tokenizer.CountTokens(text);
sum += _tokenizer.CountTokens(text.Text);
}
}

Expand Down
Loading
Loading