diff --git a/eng/MSBuild/Shared.props b/eng/MSBuild/Shared.props index a68b0e4298f..dee583f7e39 100644 --- a/eng/MSBuild/Shared.props +++ b/eng/MSBuild/Shared.props @@ -14,6 +14,10 @@ + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/JsonModelHelpers.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/JsonModelHelpers.cs new file mode 100644 index 00000000000..5f6b92d2f01 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/JsonModelHelpers.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel.Primitives; + +namespace Microsoft.Extensions.AI; + +/// +/// Defines a set of helper methods for working with types. +/// +internal static class JsonModelHelpers +{ + public static BinaryData Serialize(TModel value) + where TModel : IJsonModel + { + return value.Write(ModelReaderWriterOptions.Json); + } + + public static TModel Deserialize(BinaryData data) + where TModel : IJsonModel, new() + { + return JsonModelDeserializationWitness.Value.Create(data, ModelReaderWriterOptions.Json); + } + + private sealed class JsonModelDeserializationWitness + where TModel : IJsonModel, new() + { + public static readonly IJsonModel Value = new TModel(); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj index d3d09766d69..d3e969337e6 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj @@ -18,11 +18,15 @@ $(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002;OPENAI002 true true + true true true + true + true + true diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs 
b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index 09ad9aa18ac..d0ec35d1e22 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -4,11 +4,7 @@ using System; using System.Collections.Generic; using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text; using System.Text.Json; -using System.Text.Json.Serialization; -using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -16,7 +12,6 @@ using OpenAI.Chat; #pragma warning disable S1067 // Expressions should not be too complex -#pragma warning disable S1135 // Track uses of "TODO" tags #pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields #pragma warning disable SA1204 // Static elements should appear before instance elements #pragma warning disable SA1108 // Block statements should not contain embedded comments @@ -26,8 +21,6 @@ namespace Microsoft.Extensions.AI; /// Represents an for an OpenAI or . public sealed class OpenAIChatClient : IChatClient { - private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement; - /// Default OpenAI endpoint. private static readonly Uri _defaultOpenAIEndpoint = new("https://api.openai.com/v1"); @@ -110,224 +103,28 @@ public async Task CompleteAsync( { _ = Throw.IfNull(chatMessages); - // Make the call to OpenAI. - OpenAI.Chat.ChatCompletion response = (await _chatClient.CompleteChatAsync( - ToOpenAIChatMessages(chatMessages), - ToOpenAIOptions(options), - cancellationToken).ConfigureAwait(false)).Value; - - // Create the return message. - ChatMessage returnMessage = new() - { - RawRepresentation = response, - Role = ToChatRole(response.Role), - }; - - // Populate its content from those in the OpenAI response content. 
- foreach (ChatMessageContentPart contentPart in response.Content) - { - if (ToAIContent(contentPart) is AIContent aiContent) - { - returnMessage.Contents.Add(aiContent); - } - } - - // Also manufacture function calling content items from any tool calls in the response. - if (options?.Tools is { Count: > 0 }) - { - foreach (ChatToolCall toolCall in response.ToolCalls) - { - if (!string.IsNullOrWhiteSpace(toolCall.FunctionName)) - { - var callContent = ParseCallContentFromBinaryData(toolCall.FunctionArguments, toolCall.Id, toolCall.FunctionName); - callContent.RawRepresentation = toolCall; - - returnMessage.Contents.Add(callContent); - } - } - } - - // Wrap the content in a ChatCompletion to return. - var completion = new ChatCompletion([returnMessage]) - { - RawRepresentation = response, - CompletionId = response.Id, - CreatedAt = response.CreatedAt, - ModelId = response.Model, - FinishReason = ToFinishReason(response.FinishReason), - }; + var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(chatMessages, ToolCallJsonSerializerOptions); + var openAIOptions = OpenAIModelMappers.ToOpenAIOptions(options); - if (response.Usage is ChatTokenUsage tokenUsage) - { - completion.Usage = ToUsageDetails(tokenUsage); - } - - if (response.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) - { - (completion.AdditionalProperties ??= [])[nameof(response.ContentTokenLogProbabilities)] = contentTokenLogProbs; - } - - if (response.Refusal is string refusal) - { - (completion.AdditionalProperties ??= [])[nameof(response.Refusal)] = refusal; - } - - if (response.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) - { - (completion.AdditionalProperties ??= [])[nameof(response.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; - } - - if (response.SystemFingerprint is string systemFingerprint) - { - (completion.AdditionalProperties ??= [])[nameof(response.SystemFingerprint)] = systemFingerprint; - } + // Make the call to OpenAI. 
+ var response = await _chatClient.CompleteChatAsync(openAIChatMessages, openAIOptions, cancellationToken).ConfigureAwait(false); - return completion; + return OpenAIModelMappers.FromOpenAIChatCompletion(response.Value, options); } /// - public async IAsyncEnumerable CompleteStreamingAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(chatMessages); - Dictionary? functionCallInfos = null; - ChatRole? streamedRole = null; - ChatFinishReason? finishReason = null; - StringBuilder? refusal = null; - string? completionId = null; - DateTimeOffset? createdAt = null; - string? modelId = null; - string? fingerprint = null; - - // Process each update as it arrives - await foreach (OpenAI.Chat.StreamingChatCompletionUpdate chatCompletionUpdate in _chatClient.CompleteChatStreamingAsync( - ToOpenAIChatMessages(chatMessages), ToOpenAIOptions(options), cancellationToken).ConfigureAwait(false)) - { - // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. - streamedRole ??= chatCompletionUpdate.Role is ChatMessageRole role ? ToChatRole(role) : null; - finishReason ??= chatCompletionUpdate.FinishReason is OpenAI.Chat.ChatFinishReason reason ? ToFinishReason(reason) : null; - completionId ??= chatCompletionUpdate.CompletionId; - createdAt ??= chatCompletionUpdate.CreatedAt; - modelId ??= chatCompletionUpdate.Model; - fingerprint ??= chatCompletionUpdate.SystemFingerprint; - - // Create the response content object. 
- StreamingChatCompletionUpdate completionUpdate = new() - { - CompletionId = chatCompletionUpdate.CompletionId, - CreatedAt = chatCompletionUpdate.CreatedAt, - FinishReason = finishReason, - ModelId = modelId, - RawRepresentation = chatCompletionUpdate, - Role = streamedRole, - }; - - // Populate it with any additional metadata from the OpenAI object. - if (chatCompletionUpdate.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) - { - (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.ContentTokenLogProbabilities)] = contentTokenLogProbs; - } - - if (chatCompletionUpdate.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) - { - (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; - } - - if (fingerprint is not null) - { - (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.SystemFingerprint)] = fingerprint; - } - - // Transfer over content update items. - if (chatCompletionUpdate.ContentUpdate is { Count: > 0 }) - { - foreach (ChatMessageContentPart contentPart in chatCompletionUpdate.ContentUpdate) - { - if (ToAIContent(contentPart) is AIContent aiContent) - { - completionUpdate.Contents.Add(aiContent); - } - } - } - - // Transfer over refusal updates. - if (chatCompletionUpdate.RefusalUpdate is not null) - { - _ = (refusal ??= new()).Append(chatCompletionUpdate.RefusalUpdate); - } - - // Transfer over tool call updates. - if (chatCompletionUpdate.ToolCallUpdates is { Count: > 0 } toolCallUpdates) - { - foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) - { - functionCallInfos ??= []; - if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? 
existing)) - { - functionCallInfos[toolCallUpdate.Index] = existing = new(); - } - - existing.CallId ??= toolCallUpdate.ToolCallId; - existing.Name ??= toolCallUpdate.FunctionName; - if (toolCallUpdate.FunctionArgumentsUpdate is { } update && !update.ToMemory().IsEmpty) - { - _ = (existing.Arguments ??= new()).Append(update.ToString()); - } - } - } - - // Transfer over usage updates. - if (chatCompletionUpdate.Usage is ChatTokenUsage tokenUsage) - { - var usageDetails = ToUsageDetails(tokenUsage); - completionUpdate.Contents.Add(new UsageContent(usageDetails)); - } - - // Now yield the item. - yield return completionUpdate; - } - - // Now that we've received all updates, combine any for function calls into a single item to yield. - if (functionCallInfos is not null) - { - StreamingChatCompletionUpdate completionUpdate = new() - { - CompletionId = completionId, - CreatedAt = createdAt, - FinishReason = finishReason, - ModelId = modelId, - Role = streamedRole, - }; + var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(chatMessages, ToolCallJsonSerializerOptions); + var openAIOptions = OpenAIModelMappers.ToOpenAIOptions(options); - foreach (var entry in functionCallInfos) - { - FunctionCallInfo fci = entry.Value; - if (!string.IsNullOrWhiteSpace(fci.Name)) - { - var callContent = ParseCallContentFromJsonString( - fci.Arguments?.ToString() ?? string.Empty, - fci.CallId!, - fci.Name!); - completionUpdate.Contents.Add(callContent); - } - } - - // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, - // add it to this function calling item. - if (refusal is not null) - { - (completionUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); - } - - // Propagate additional relevant metadata. 
- if (fingerprint is not null) - { - (completionUpdate.AdditionalProperties ??= [])[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = fingerprint; - } + // Make the call to OpenAI. + var chatCompletionUpdates = _chatClient.CompleteChatStreamingAsync(openAIChatMessages, openAIOptions, cancellationToken); - yield return completionUpdate; - } + return OpenAIModelMappers.FromOpenAIStreamingChatCompletionAsync(chatCompletionUpdates, cancellationToken); } /// @@ -335,376 +132,4 @@ void IDisposable.Dispose() { // Nothing to dispose. Implementation required for the IChatClient interface. } - - /// POCO representing function calling info. Used to concatenation information for a single function call from across multiple streaming updates. - private sealed class FunctionCallInfo - { - public string? CallId; - public string? Name; - public StringBuilder? Arguments; - } - - private static UsageDetails ToUsageDetails(ChatTokenUsage tokenUsage) - { - var destination = new UsageDetails - { - InputTokenCount = tokenUsage.InputTokenCount, - OutputTokenCount = tokenUsage.OutputTokenCount, - TotalTokenCount = tokenUsage.TotalTokenCount, - AdditionalCounts = new(), - }; - - if (tokenUsage.InputTokenDetails is ChatInputTokenUsageDetails inputDetails) - { - destination.AdditionalCounts.Add( - $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}", - inputDetails.AudioTokenCount); - - destination.AdditionalCounts.Add( - $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}", - inputDetails.CachedTokenCount); - } - - if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails outputDetails) - { - destination.AdditionalCounts.Add( - $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}", - outputDetails.AudioTokenCount); - - destination.AdditionalCounts.Add( - 
$"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}", - outputDetails.ReasoningTokenCount); - } - - return destination; - } - - /// Converts an OpenAI role to an Extensions role. - private static ChatRole ToChatRole(ChatMessageRole role) => - role switch - { - ChatMessageRole.System => ChatRole.System, - ChatMessageRole.User => ChatRole.User, - ChatMessageRole.Assistant => ChatRole.Assistant, - ChatMessageRole.Tool => ChatRole.Tool, - _ => new ChatRole(role.ToString()), - }; - - /// Converts an OpenAI finish reason to an Extensions finish reason. - private static ChatFinishReason? ToFinishReason(OpenAI.Chat.ChatFinishReason? finishReason) => - finishReason?.ToString() is not string s ? null : - finishReason switch - { - OpenAI.Chat.ChatFinishReason.Stop => ChatFinishReason.Stop, - OpenAI.Chat.ChatFinishReason.Length => ChatFinishReason.Length, - OpenAI.Chat.ChatFinishReason.ContentFilter => ChatFinishReason.ContentFilter, - OpenAI.Chat.ChatFinishReason.ToolCalls or OpenAI.Chat.ChatFinishReason.FunctionCall => ChatFinishReason.ToolCalls, - _ => new ChatFinishReason(s), - }; - - /// Converts an extensions options instance to an OpenAI options instance. - private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) - { - ChatCompletionOptions result = new(); - - if (options is not null) - { - result.FrequencyPenalty = options.FrequencyPenalty; - result.MaxOutputTokenCount = options.MaxOutputTokens; - result.TopP = options.TopP; - result.PresencePenalty = options.PresencePenalty; - result.Temperature = options.Temperature; -#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. 
- result.Seed = options.Seed; -#pragma warning restore OPENAI001 - - if (options.StopSequences is { Count: > 0 } stopSequences) - { - foreach (string stopSequence in stopSequences) - { - result.StopSequences.Add(stopSequence); - } - } - - if (options.AdditionalProperties is { Count: > 0 } additionalProperties) - { - if (additionalProperties.TryGetValue(nameof(result.EndUserId), out string? endUserId)) - { - result.EndUserId = endUserId; - } - - if (additionalProperties.TryGetValue(nameof(result.IncludeLogProbabilities), out bool includeLogProbabilities)) - { - result.IncludeLogProbabilities = includeLogProbabilities; - } - - if (additionalProperties.TryGetValue(nameof(result.LogitBiases), out IDictionary? logitBiases)) - { - foreach (KeyValuePair kvp in logitBiases!) - { - result.LogitBiases[kvp.Key] = kvp.Value; - } - } - - if (additionalProperties.TryGetValue(nameof(result.AllowParallelToolCalls), out bool allowParallelToolCalls)) - { - result.AllowParallelToolCalls = allowParallelToolCalls; - } - - if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt)) - { - result.TopLogProbabilityCount = topLogProbabilityCountInt; - } - - if (additionalProperties.TryGetValue(nameof(result.Metadata), out IDictionary? metadata)) - { - foreach (KeyValuePair kvp in metadata) - { - result.Metadata[kvp.Key] = kvp.Value; - } - } - - if (additionalProperties.TryGetValue(nameof(result.StoredOutputEnabled), out bool storeOutputEnabled)) - { - result.StoredOutputEnabled = storeOutputEnabled; - } - } - - if (options.Tools is { Count: > 0 } tools) - { - foreach (AITool tool in tools) - { - if (tool is AIFunction af) - { - result.Tools.Add(ToOpenAIChatTool(af)); - } - } - - switch (options.ToolMode) - { - case AutoChatToolMode: - result.ToolChoice = ChatToolChoice.CreateAutoChoice(); - break; - - case RequiredChatToolMode required: - result.ToolChoice = required.RequiredFunctionName is null ? 
- ChatToolChoice.CreateRequiredChoice() : - ChatToolChoice.CreateFunctionChoice(required.RequiredFunctionName); - break; - } - } - - if (options.ResponseFormat is ChatResponseFormatText) - { - result.ResponseFormat = OpenAI.Chat.ChatResponseFormat.CreateTextFormat(); - } - else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat) - { - result.ResponseFormat = jsonFormat.Schema is { } jsonSchema ? - OpenAI.Chat.ChatResponseFormat.CreateJsonSchemaFormat( - jsonFormat.SchemaName ?? "json_schema", - BinaryData.FromBytes( - JsonSerializer.SerializeToUtf8Bytes(jsonSchema, OpenAIJsonContext.Default.JsonElement)), - jsonFormat.SchemaDescription) : - OpenAI.Chat.ChatResponseFormat.CreateJsonObjectFormat(); - } - } - - return result; - } - - /// Converts an Extensions function to an OpenAI chat tool. - private static ChatTool ToOpenAIChatTool(AIFunction aiFunction) - { - bool? strict = - aiFunction.Metadata.AdditionalProperties.TryGetValue("Strict", out object? strictObj) && - strictObj is bool strictValue ? - strictValue : null; - - BinaryData resultParameters = OpenAIChatToolJson.ZeroFunctionParametersSchema; - - var parameters = aiFunction.Metadata.Parameters; - if (parameters is { Count: > 0 }) - { - OpenAIChatToolJson tool = new(); - - foreach (AIFunctionParameterMetadata parameter in parameters) - { - tool.Properties.Add(parameter.Name, parameter.Schema is JsonElement e ? e : _defaultParameterSchema); - - if (parameter.IsRequired) - { - tool.Required.Add(parameter.Name); - } - } - - resultParameters = BinaryData.FromBytes( - JsonSerializer.SerializeToUtf8Bytes(tool, OpenAIJsonContext.Default.OpenAIChatToolJson)); - } - - return ChatTool.CreateFunctionTool(aiFunction.Metadata.Name, aiFunction.Metadata.Description, resultParameters, strict); - } - - /// Used to create the JSON payload for an OpenAI chat tool description. - internal sealed class OpenAIChatToolJson - { - /// Gets a singleton JSON data for empty parameters. 
Optimization for the reasonably common case of a parameterless function. - public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray()); - - [JsonPropertyName("type")] - public string Type { get; set; } = "object"; - - [JsonPropertyName("required")] - public List Required { get; set; } = []; - - [JsonPropertyName("properties")] - public Dictionary Properties { get; set; } = []; - } - - /// Creates an from a . - /// The content part to convert into a content. - /// The constructed , or null if the content part could not be converted. - private static AIContent? ToAIContent(ChatMessageContentPart contentPart) - { - AIContent? aiContent = null; - - if (contentPart.Kind == ChatMessageContentPartKind.Text) - { - aiContent = new TextContent(contentPart.Text); - } - else if (contentPart.Kind == ChatMessageContentPartKind.Image) - { - ImageContent? imageContent; - aiContent = imageContent = - contentPart.ImageUri is not null ? new ImageContent(contentPart.ImageUri, contentPart.ImageBytesMediaType) : - contentPart.ImageBytes is not null ? new ImageContent(contentPart.ImageBytes.ToMemory(), contentPart.ImageBytesMediaType) : - null; - - if (imageContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) - { - (imageContent.AdditionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; - } - } - - if (aiContent is not null) - { - if (contentPart.Refusal is string refusal) - { - (aiContent.AdditionalProperties ??= [])[nameof(contentPart.Refusal)] = refusal; - } - - aiContent.RawRepresentation = contentPart; - } - - return aiContent; - } - - /// Converts an Extensions chat message enumerable to an OpenAI chat message enumerable. - private IEnumerable ToOpenAIChatMessages(IEnumerable inputs) - { - // Maps all of the M.E.AI types to the corresponding OpenAI types. - // Unrecognized or non-processable content is ignored. 
- - foreach (ChatMessage input in inputs) - { - if (input.Role == ChatRole.System || input.Role == ChatRole.User) - { - var parts = GetContentParts(input.Contents); - yield return input.Role == ChatRole.System ? - new SystemChatMessage(parts) { ParticipantName = input.AuthorName } : - new UserChatMessage(parts) { ParticipantName = input.AuthorName }; - } - else if (input.Role == ChatRole.Tool) - { - foreach (AIContent item in input.Contents) - { - if (item is FunctionResultContent resultContent) - { - string? result = resultContent.Result as string; - if (result is null && resultContent.Result is not null) - { - try - { - result = JsonSerializer.Serialize(resultContent.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object))); - } - catch (NotSupportedException) - { - // If the type can't be serialized, skip it. - } - } - - yield return new ToolChatMessage(resultContent.CallId, result ?? string.Empty); - } - } - } - else if (input.Role == ChatRole.Assistant) - { - AssistantChatMessage message = new(GetContentParts(input.Contents)) - { - ParticipantName = input.AuthorName - }; - - foreach (var content in input.Contents) - { - if (content is FunctionCallContent { CallId: not null } callRequest) - { - message.ToolCalls.Add( - ChatToolCall.CreateFunctionToolCall( - callRequest.CallId, - callRequest.Name, - new(JsonSerializer.SerializeToUtf8Bytes( - callRequest.Arguments, - ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary)))))); - } - } - - if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true) - { - message.Refusal = refusal; - } - - yield return message; - } - } - } - - /// Converts a list of to a list of . 
- private static List GetContentParts(IList contents) - { - List parts = []; - foreach (var content in contents) - { - switch (content) - { - case TextContent textContent: - parts.Add(ChatMessageContentPart.CreateTextPart(textContent.Text)); - break; - - case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data: - parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType)); - break; - - case ImageContent imageContent when imageContent.Uri is string uri: - parts.Add(ChatMessageContentPart.CreateImagePart(new Uri(uri))); - break; - } - } - - if (parts.Count == 0) - { - parts.Add(ChatMessageContentPart.CreateTextPart(string.Empty)); - } - - return parts; - } - - private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) => - FunctionCallContent.CreateFromParsedArguments(json, callId, name, - argumentParser: static json => JsonSerializer.Deserialize(json, - (JsonTypeInfo>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary)))!); - - private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8Json, string callId, string name) => - FunctionCallContent.CreateFromParsedArguments(ut8Json, callId, name, - argumentParser: static json => JsonSerializer.Deserialize(json, - (JsonTypeInfo>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary)))!); } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatCompletionRequest.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatCompletionRequest.cs new file mode 100644 index 00000000000..dba0e5ecbf8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatCompletionRequest.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents an OpenAI chat completion request deserialized as Microsoft.Extensions.AI models. +/// +public sealed class OpenAIChatCompletionRequest +{ + /// + /// Gets the chat messages specified in the completion request. + /// + public required IList Messages { get; init; } + + /// + /// Gets the chat options governing the completion request. + /// + public required ChatOptions Options { get; init; } + + /// + /// Gets a value indicating whether the completion response should be streamed. + /// + public bool Stream { get; init; } + + /// + /// Gets the model id requested by the chat completion. + /// + public string? ModelId { get; init; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs index 9cd075e1d04..69f610b4818 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license.
+using System.Collections.Generic; using System.Text.Json; using System.Text.Json.Serialization; @@ -11,6 +12,7 @@ namespace Microsoft.Extensions.AI; UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, WriteIndented = true)] -[JsonSerializable(typeof(OpenAIChatClient.OpenAIChatToolJson))] [JsonSerializable(typeof(OpenAIRealtimeExtensions.ConversationFunctionToolParametersSchema))] +[JsonSerializable(typeof(OpenAIModelMappers.OpenAIChatToolJson))] +[JsonSerializable(typeof(IDictionary))] internal sealed partial class OpenAIJsonContext : JsonSerializerContext; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs new file mode 100644 index 00000000000..9f35727cf80 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -0,0 +1,610 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; +using OpenAI.Chat; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable CA1859 // Use concrete types when possible for improved performance +#pragma warning disable S1067 // Expressions should not be too complex + +namespace Microsoft.Extensions.AI; + +internal static partial class OpenAIModelMappers +{ + private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement; + + public static OpenAI.Chat.ChatCompletion ToOpenAIChatCompletion(ChatCompletion chatCompletion, JsonSerializerOptions options) + { + _ = Throw.IfNull(chatCompletion); + + if (chatCompletion.Choices.Count > 1) + { + throw new NotSupportedException("Creating OpenAI ChatCompletion models with multiple choices is currently not supported."); + } + + List? toolCalls = null; + foreach (AIContent content in chatCompletion.Message.Contents) + { + if (content is FunctionCallContent callRequest) + { + toolCalls ??= []; + toolCalls.Add(ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + options.GetTypeInfo(typeof(IDictionary)))))); + } + } + + OpenAI.Chat.ChatTokenUsage? chatTokenUsage = null; + if (chatCompletion.Usage is UsageDetails usageDetails) + { + chatTokenUsage = ToOpenAIUsage(usageDetails); + } + + return OpenAIChatModelFactory.ChatCompletion( + id: chatCompletion.CompletionId, + model: chatCompletion.ModelId, + createdAt: chatCompletion.CreatedAt ?? 
default, + role: ToOpenAIChatRole(chatCompletion.Message.Role).Value, + finishReason: ToOpenAIFinishReason(chatCompletion.FinishReason), + content: new(ToOpenAIChatContent(chatCompletion.Message.Contents)), + toolCalls: toolCalls, + refusal: chatCompletion.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.ChatCompletion.Refusal)), + contentTokenLogProbabilities: chatCompletion.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.ChatCompletion.ContentTokenLogProbabilities)), + refusalTokenLogProbabilities: chatCompletion.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.ChatCompletion.RefusalTokenLogProbabilities)), + systemFingerprint: chatCompletion.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)), + usage: chatTokenUsage); + } + + public static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompletion openAICompletion, ChatOptions? options) + { + _ = Throw.IfNull(openAICompletion); + + // Create the return message. + ChatMessage returnMessage = new() + { + RawRepresentation = openAICompletion, + Role = FromOpenAIChatRole(openAICompletion.Role), + }; + + // Populate its content from those in the OpenAI response content. + foreach (ChatMessageContentPart contentPart in openAICompletion.Content) + { + if (ToAIContent(contentPart) is AIContent aiContent) + { + returnMessage.Contents.Add(aiContent); + } + } + + // Also manufacture function calling content items from any tool calls in the response. + if (options?.Tools is { Count: > 0 }) + { + foreach (ChatToolCall toolCall in openAICompletion.ToolCalls) + { + if (!string.IsNullOrWhiteSpace(toolCall.FunctionName)) + { + var callContent = ParseCallContentFromBinaryData(toolCall.FunctionArguments, toolCall.Id, toolCall.FunctionName); + callContent.RawRepresentation = toolCall; + + returnMessage.Contents.Add(callContent); + } + } + } + + // Wrap the content in a ChatCompletion to return. 
+ var completion = new ChatCompletion([returnMessage]) + { + RawRepresentation = openAICompletion, + CompletionId = openAICompletion.Id, + CreatedAt = openAICompletion.CreatedAt, + ModelId = openAICompletion.Model, + FinishReason = FromOpenAIFinishReason(openAICompletion.FinishReason), + }; + + if (openAICompletion.Usage is ChatTokenUsage tokenUsage) + { + completion.Usage = FromOpenAIUsage(tokenUsage); + } + + if (openAICompletion.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (completion.AdditionalProperties ??= [])[nameof(openAICompletion.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } + + if (openAICompletion.Refusal is string refusal) + { + (completion.AdditionalProperties ??= [])[nameof(openAICompletion.Refusal)] = refusal; + } + + if (openAICompletion.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (completion.AdditionalProperties ??= [])[nameof(openAICompletion.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } + + if (openAICompletion.SystemFingerprint is string systemFingerprint) + { + (completion.AdditionalProperties ??= [])[nameof(openAICompletion.SystemFingerprint)] = systemFingerprint; + } + + return completion; + } + + public static ChatOptions FromOpenAIOptions(OpenAI.Chat.ChatCompletionOptions? options) + { + ChatOptions result = new(); + + if (options is not null) + { + result.ModelId = _getModelIdAccessor.Invoke(options, null)?.ToString(); + result.FrequencyPenalty = options.FrequencyPenalty; + result.MaxOutputTokens = options.MaxOutputTokenCount; + result.TopP = options.TopP; + result.PresencePenalty = options.PresencePenalty; + result.Temperature = options.Temperature; +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. + result.Seed = options.Seed; +#pragma warning restore OPENAI001 + + if (options.StopSequences is { Count: > 0 } stopSequences) + { + result.StopSequences = [.. 
stopSequences]; + } + + if (options.EndUserId is string endUserId) + { + (result.AdditionalProperties ??= [])[nameof(options.EndUserId)] = endUserId; + } + + if (options.IncludeLogProbabilities is bool includeLogProbabilities) + { + (result.AdditionalProperties ??= [])[nameof(options.IncludeLogProbabilities)] = includeLogProbabilities; + } + + if (options.LogitBiases is { Count: > 0 } logitBiases) + { + (result.AdditionalProperties ??= [])[nameof(options.LogitBiases)] = new Dictionary(logitBiases); + } + + if (options.AllowParallelToolCalls is bool allowParallelToolCalls) + { + (result.AdditionalProperties ??= [])[nameof(options.AllowParallelToolCalls)] = allowParallelToolCalls; + } + + if (options.TopLogProbabilityCount is int topLogProbabilityCount) + { + (result.AdditionalProperties ??= [])[nameof(options.TopLogProbabilityCount)] = topLogProbabilityCount; + } + + if (options.Metadata is IDictionary { Count: > 0 } metadata) + { + (result.AdditionalProperties ??= [])[nameof(options.Metadata)] = new Dictionary(metadata); + } + + if (options.StoredOutputEnabled is bool storedOutputEnabled) + { + (result.AdditionalProperties ??= [])[nameof(options.StoredOutputEnabled)] = storedOutputEnabled; + } + + if (options.Tools is { Count: > 0 } tools) + { + foreach (ChatTool tool in tools) + { + result.Tools ??= []; + result.Tools.Add(FromOpenAIChatTool(tool)); + } + + using var toolChoiceJson = JsonDocument.Parse(JsonModelHelpers.Serialize(options.ToolChoice).ToMemory()); + JsonElement jsonElement = toolChoiceJson.RootElement; + switch (jsonElement.ValueKind) + { + case JsonValueKind.String: + result.ToolMode = jsonElement.GetString() switch + { + "required" => ChatToolMode.RequireAny, + _ => ChatToolMode.Auto, + }; + + break; + case JsonValueKind.Object: + if (jsonElement.TryGetProperty("function", out JsonElement functionElement)) + { + result.ToolMode = ChatToolMode.RequireSpecific(functionElement.GetString()!); + } + + break; + } + } + } + + return result; + } + + /// 
Converts an extensions options instance to an OpenAI options instance. + public static OpenAI.Chat.ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) + { + ChatCompletionOptions result = new(); + + if (options is not null) + { + result.FrequencyPenalty = options.FrequencyPenalty; + result.MaxOutputTokenCount = options.MaxOutputTokens; + result.TopP = options.TopP; + result.PresencePenalty = options.PresencePenalty; + result.Temperature = options.Temperature; +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. + result.Seed = options.Seed; +#pragma warning restore OPENAI001 + + if (options.StopSequences is { Count: > 0 } stopSequences) + { + foreach (string stopSequence in stopSequences) + { + result.StopSequences.Add(stopSequence); + } + } + + if (options.AdditionalProperties is { Count: > 0 } additionalProperties) + { + if (additionalProperties.TryGetValue(nameof(result.EndUserId), out string? endUserId)) + { + result.EndUserId = endUserId; + } + + if (additionalProperties.TryGetValue(nameof(result.IncludeLogProbabilities), out bool includeLogProbabilities)) + { + result.IncludeLogProbabilities = includeLogProbabilities; + } + + if (additionalProperties.TryGetValue(nameof(result.LogitBiases), out IDictionary? logitBiases)) + { + foreach (KeyValuePair kvp in logitBiases!) + { + result.LogitBiases[kvp.Key] = kvp.Value; + } + } + + if (additionalProperties.TryGetValue(nameof(result.AllowParallelToolCalls), out bool allowParallelToolCalls)) + { + result.AllowParallelToolCalls = allowParallelToolCalls; + } + + if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt)) + { + result.TopLogProbabilityCount = topLogProbabilityCountInt; + } + + if (additionalProperties.TryGetValue(nameof(result.Metadata), out IDictionary? 
metadata)) + { + foreach (KeyValuePair kvp in metadata) + { + result.Metadata[kvp.Key] = kvp.Value; + } + } + + if (additionalProperties.TryGetValue(nameof(result.StoredOutputEnabled), out bool storeOutputEnabled)) + { + result.StoredOutputEnabled = storeOutputEnabled; + } + } + + if (options.Tools is { Count: > 0 } tools) + { + foreach (AITool tool in tools) + { + if (tool is AIFunction af) + { + result.Tools.Add(ToOpenAIChatTool(af)); + } + } + + switch (options.ToolMode) + { + case AutoChatToolMode: + result.ToolChoice = ChatToolChoice.CreateAutoChoice(); + break; + + case RequiredChatToolMode required: + result.ToolChoice = required.RequiredFunctionName is null ? + ChatToolChoice.CreateRequiredChoice() : + ChatToolChoice.CreateFunctionChoice(required.RequiredFunctionName); + break; + } + } + + if (options.ResponseFormat is ChatResponseFormatText) + { + result.ResponseFormat = OpenAI.Chat.ChatResponseFormat.CreateTextFormat(); + } + else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + result.ResponseFormat = jsonFormat.Schema is { } jsonSchema ? + OpenAI.Chat.ChatResponseFormat.CreateJsonSchemaFormat( + jsonFormat.SchemaName ?? 
"json_schema", + BinaryData.FromBytes( + JsonSerializer.SerializeToUtf8Bytes(jsonSchema, OpenAIJsonContext.Default.JsonElement)), + jsonFormat.SchemaDescription) : + OpenAI.Chat.ChatResponseFormat.CreateJsonObjectFormat(); + } + } + + return result; + } + + private static AITool FromOpenAIChatTool(ChatTool chatTool) + { + AdditionalPropertiesDictionary additionalProperties = new(); + if (chatTool.FunctionSchemaIsStrict is bool strictValue) + { + additionalProperties["Strict"] = strictValue; + } + + OpenAIChatToolJson openAiChatTool = JsonSerializer.Deserialize(chatTool.FunctionParameters.ToMemory().Span, OpenAIJsonContext.Default.OpenAIChatToolJson)!; + List parameters = new(openAiChatTool.Properties.Count); + foreach (KeyValuePair property in openAiChatTool.Properties) + { + parameters.Add(new(property.Key) + { + Schema = property.Value, + IsRequired = openAiChatTool.Required.Contains(property.Key), + }); + } + + AIFunctionMetadata metadata = new(chatTool.FunctionName) + { + Description = chatTool.FunctionDescription, + AdditionalProperties = additionalProperties, + Parameters = parameters, + ReturnParameter = new() + { + Description = "Return parameter", + Schema = _defaultParameterSchema, + } + }; + + return new MetadataOnlyAIFunction(metadata); + } + + private sealed class MetadataOnlyAIFunction(AIFunctionMetadata metadata) : AIFunction + { + public override AIFunctionMetadata Metadata => metadata; + protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken) => + throw new InvalidOperationException($"The AI function '{metadata.Name}' does not support being invoked."); + } + + /// Converts an Extensions function to an OpenAI chat tool. + private static ChatTool ToOpenAIChatTool(AIFunction aiFunction) + { + bool? strict = + aiFunction.Metadata.AdditionalProperties.TryGetValue("Strict", out object? strictObj) && + strictObj is bool strictValue ? 
+ strictValue : null; + + BinaryData resultParameters = OpenAIChatToolJson.ZeroFunctionParametersSchema; + + var parameters = aiFunction.Metadata.Parameters; + if (parameters is { Count: > 0 }) + { + OpenAIChatToolJson tool = new(); + + foreach (AIFunctionParameterMetadata parameter in parameters) + { + tool.Properties.Add(parameter.Name, parameter.Schema is JsonElement e ? e : _defaultParameterSchema); + + if (parameter.IsRequired) + { + _ = tool.Required.Add(parameter.Name); + } + } + + resultParameters = BinaryData.FromBytes( + JsonSerializer.SerializeToUtf8Bytes(tool, OpenAIJsonContext.Default.OpenAIChatToolJson)); + } + + return ChatTool.CreateFunctionTool(aiFunction.Metadata.Name, aiFunction.Metadata.Description, resultParameters, strict); + } + + private static UsageDetails FromOpenAIUsage(ChatTokenUsage tokenUsage) + { + var destination = new UsageDetails + { + InputTokenCount = tokenUsage.InputTokenCount, + OutputTokenCount = tokenUsage.OutputTokenCount, + TotalTokenCount = tokenUsage.TotalTokenCount, + AdditionalCounts = new(), + }; + + if (tokenUsage.InputTokenDetails is ChatInputTokenUsageDetails inputDetails) + { + destination.AdditionalCounts.Add( + $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}", + inputDetails.AudioTokenCount); + + destination.AdditionalCounts.Add( + $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}", + inputDetails.CachedTokenCount); + } + + if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails outputDetails) + { + destination.AdditionalCounts.Add( + $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}", + outputDetails.AudioTokenCount); + + destination.AdditionalCounts.Add( + $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}", + outputDetails.ReasoningTokenCount); + } + + return destination; + } + + private static 
ChatTokenUsage ToOpenAIUsage(UsageDetails usageDetails) + { + ChatOutputTokenUsageDetails? outputTokenUsageDetails = null; + ChatInputTokenUsageDetails? inputTokenUsageDetails = null; + + if (usageDetails.AdditionalCounts is { Count: > 0 } additionalCounts) + { + int? inputAudioTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}", + out int value) ? value : null; + + int? inputCachedTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}", + out value) ? value : null; + + int? outputAudioTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}", + out value) ? value : null; + + int? outputReasoningTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}", + out value) ? value : null; + + if (inputAudioTokenCount is not null || inputCachedTokenCount is not null) + { + inputTokenUsageDetails = OpenAIChatModelFactory.ChatInputTokenUsageDetails( + audioTokenCount: inputAudioTokenCount ?? 0, + cachedTokenCount: inputCachedTokenCount ?? 0); + } + + if (outputAudioTokenCount is not null || outputReasoningTokenCount is not null) + { + outputTokenUsageDetails = OpenAIChatModelFactory.ChatOutputTokenUsageDetails( + audioTokenCount: outputAudioTokenCount ?? 0, + reasoningTokenCount: outputReasoningTokenCount ?? 0); + } + } + + return OpenAIChatModelFactory.ChatTokenUsage( + inputTokenCount: usageDetails.InputTokenCount ?? 0, + outputTokenCount: usageDetails.OutputTokenCount ?? 0, + totalTokenCount: usageDetails.TotalTokenCount ?? 0, + outputTokenDetails: outputTokenUsageDetails, + inputTokenDetails: inputTokenUsageDetails); + } + + /// Converts an OpenAI role to an Extensions role. 
+ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) => + role switch + { + ChatMessageRole.System => ChatRole.System, + ChatMessageRole.User => ChatRole.User, + ChatMessageRole.Assistant => ChatRole.Assistant, + ChatMessageRole.Tool => ChatRole.Tool, + _ => new ChatRole(role.ToString()), + }; + + /// Converts an Extensions role to an OpenAI role. + [return: NotNullIfNotNull("role")] + private static ChatMessageRole? ToOpenAIChatRole(ChatRole? role) => + role is null ? null : + role == ChatRole.System ? ChatMessageRole.System : + role == ChatRole.User ? ChatMessageRole.User : + role == ChatRole.Assistant ? ChatMessageRole.Assistant : + role == ChatRole.Tool ? ChatMessageRole.Tool : ChatMessageRole.User; + + /// Creates an from a . + /// The content part to convert into a content. + /// The constructed , or null if the content part could not be converted. + private static AIContent? ToAIContent(ChatMessageContentPart contentPart) + { + AIContent? aiContent = null; + + if (contentPart.Kind == ChatMessageContentPartKind.Text) + { + aiContent = new TextContent(contentPart.Text); + } + else if (contentPart.Kind == ChatMessageContentPartKind.Image) + { + ImageContent? imageContent; + aiContent = imageContent = + contentPart.ImageUri is not null ? new ImageContent(contentPart.ImageUri, contentPart.ImageBytesMediaType) : + contentPart.ImageBytes is not null ? 
new ImageContent(contentPart.ImageBytes.ToMemory(), contentPart.ImageBytesMediaType) : + null; + + if (imageContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) + { + (imageContent.AdditionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; + } + } + + if (aiContent is not null) + { + if (contentPart.Refusal is string refusal) + { + (aiContent.AdditionalProperties ??= [])[nameof(contentPart.Refusal)] = refusal; + } + + aiContent.RawRepresentation = contentPart; + } + + return aiContent; + } + + /// Converts an OpenAI finish reason to an Extensions finish reason. + private static ChatFinishReason? FromOpenAIFinishReason(OpenAI.Chat.ChatFinishReason? finishReason) => + finishReason?.ToString() is not string s ? null : + finishReason switch + { + OpenAI.Chat.ChatFinishReason.Stop => ChatFinishReason.Stop, + OpenAI.Chat.ChatFinishReason.Length => ChatFinishReason.Length, + OpenAI.Chat.ChatFinishReason.ContentFilter => ChatFinishReason.ContentFilter, + OpenAI.Chat.ChatFinishReason.ToolCalls or OpenAI.Chat.ChatFinishReason.FunctionCall => ChatFinishReason.ToolCalls, + _ => new ChatFinishReason(s), + }; + + /// Converts an Extensions finish reason to an OpenAI finish reason. + private static OpenAI.Chat.ChatFinishReason ToOpenAIFinishReason(ChatFinishReason? finishReason) => + finishReason == ChatFinishReason.Length ? OpenAI.Chat.ChatFinishReason.Length : + finishReason == ChatFinishReason.ContentFilter ? OpenAI.Chat.ChatFinishReason.ContentFilter : + finishReason == ChatFinishReason.ToolCalls ? 
OpenAI.Chat.ChatFinishReason.ToolCalls : + OpenAI.Chat.ChatFinishReason.Stop; + + private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) => + FunctionCallContent.CreateFromParsedArguments(json, callId, name, + argumentParser: static json => JsonSerializer.Deserialize(json, OpenAIJsonContext.Default.IDictionaryStringObject)!); + + private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8Json, string callId, string name) => + FunctionCallContent.CreateFromParsedArguments(ut8Json, callId, name, + argumentParser: static json => JsonSerializer.Deserialize(json, OpenAIJsonContext.Default.IDictionaryStringObject)!); + + private static T? GetValueOrDefault(this AdditionalPropertiesDictionary? dict, string key) => + dict?.TryGetValue(key, out T? value) is true ? value : default; + + /// Used to create the JSON payload for an OpenAI chat tool description. + public sealed class OpenAIChatToolJson + { + /// Gets a singleton JSON data for empty parameters. Optimization for the reasonably common case of a parameterless function. + public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray()); + + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + [JsonPropertyName("required")] + public HashSet Required { get; set; } = []; + + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } = []; + } + + /// POCO representing function calling info. Used to concatenation information for a single function call from across multiple streaming updates. + private sealed class FunctionCallInfo + { + public string? CallId; + public string? Name; + public StringBuilder? 
Arguments; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs new file mode 100644 index 00000000000..e8193df24d5 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs @@ -0,0 +1,255 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable SA1204 // Static elements should appear before instance elements + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using OpenAI.Chat; + +namespace Microsoft.Extensions.AI; + +internal static partial class OpenAIModelMappers +{ + public static OpenAIChatCompletionRequest FromOpenAIChatCompletionRequest(OpenAI.Chat.ChatCompletionOptions chatCompletionOptions) + { + ChatOptions chatOptions = FromOpenAIOptions(chatCompletionOptions); + IList messages = FromOpenAIChatMessages(_getMessagesAccessor(chatCompletionOptions)).ToList(); + return new() + { + Messages = messages, + Options = chatOptions, + ModelId = chatOptions.ModelId, + Stream = _getStreamAccessor(chatCompletionOptions) ?? false, + }; + } + + public static IEnumerable FromOpenAIChatMessages(IEnumerable inputs) + { + // Maps all of the OpenAI types to the corresponding M.E.AI types. + // Unrecognized or non-processable content is ignored. + + Dictionary? 
functionCalls = null; + + foreach (OpenAI.Chat.ChatMessage input in inputs) + { + switch (input) + { + case SystemChatMessage systemMessage: + yield return new ChatMessage + { + Role = ChatRole.System, + AuthorName = systemMessage.ParticipantName, + Contents = FromOpenAIChatContent(systemMessage.Content), + }; + break; + + case UserChatMessage userMessage: + yield return new ChatMessage + { + Role = ChatRole.User, + AuthorName = userMessage.ParticipantName, + Contents = FromOpenAIChatContent(userMessage.Content), + }; + break; + + case ToolChatMessage toolMessage: + string textContent = string.Join(string.Empty, toolMessage.Content.Where(part => part.Kind is ChatMessageContentPartKind.Text).Select(part => part.Text)); + object? result = textContent; + if (!string.IsNullOrEmpty(textContent)) + { +#pragma warning disable CA1031 // Do not catch general exception types + try + { + result = JsonSerializer.Deserialize(textContent, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))); + } + catch + { + // If the content can't be deserialized, leave it as a string. + } +#pragma warning restore CA1031 // Do not catch general exception types + } + + string functionName = functionCalls?.TryGetValue(toolMessage.ToolCallId, out string? name) is true ? 
name : string.Empty; + yield return new ChatMessage + { + Role = ChatRole.Tool, + Contents = new AIContent[] { new FunctionResultContent(toolMessage.ToolCallId, functionName, result) }, + }; + break; + + case AssistantChatMessage assistantMessage: + + ChatMessage message = new() + { + Role = ChatRole.Assistant, + AuthorName = assistantMessage.ParticipantName, + Contents = FromOpenAIChatContent(assistantMessage.Content), + }; + + foreach (ChatToolCall toolCall in assistantMessage.ToolCalls) + { + if (!string.IsNullOrWhiteSpace(toolCall.FunctionName)) + { + var callContent = ParseCallContentFromBinaryData(toolCall.FunctionArguments, toolCall.Id, toolCall.FunctionName); + callContent.RawRepresentation = toolCall; + + message.Contents.Add(callContent); + (functionCalls ??= new()).Add(toolCall.Id, toolCall.FunctionName); + } + } + + if (assistantMessage.Refusal is not null) + { + message.AdditionalProperties ??= []; + message.AdditionalProperties.Add(nameof(assistantMessage.Refusal), assistantMessage.Refusal); + } + + yield return message; + break; + } + } + } + + /// Converts an Extensions chat message enumerable to an OpenAI chat message enumerable. + public static IEnumerable ToOpenAIChatMessages(IEnumerable inputs, JsonSerializerOptions options) + { + // Maps all of the M.E.AI types to the corresponding OpenAI types. + // Unrecognized or non-processable content is ignored. + + foreach (ChatMessage input in inputs) + { + if (input.Role == ChatRole.System || input.Role == ChatRole.User) + { + var parts = ToOpenAIChatContent(input.Contents); + yield return input.Role == ChatRole.System ? + new SystemChatMessage(parts) { ParticipantName = input.AuthorName } : + new UserChatMessage(parts) { ParticipantName = input.AuthorName }; + } + else if (input.Role == ChatRole.Tool) + { + foreach (AIContent item in input.Contents) + { + if (item is FunctionResultContent resultContent) + { + string? 
result = resultContent.Result as string; + if (result is null && resultContent.Result is not null) + { + try + { + result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object))); + } + catch (NotSupportedException) + { + // If the type can't be serialized, skip it. + } + } + + yield return new ToolChatMessage(resultContent.CallId, result ?? string.Empty); + } + } + } + else if (input.Role == ChatRole.Assistant) + { + AssistantChatMessage message = new(ToOpenAIChatContent(input.Contents)) + { + ParticipantName = input.AuthorName + }; + + foreach (var content in input.Contents) + { + if (content is FunctionCallContent callRequest) + { + message.ToolCalls.Add( + ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + options.GetTypeInfo(typeof(IDictionary)))))); + } + } + + if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true) + { + message.Refusal = refusal; + } + + yield return message; + } + } + } + + private static List FromOpenAIChatContent(IList openAiMessageContentParts) + { + List contents = new(); + foreach (var openAiContentPart in openAiMessageContentParts) + { + switch (openAiContentPart.Kind) + { + case ChatMessageContentPartKind.Text: + contents.Add(new TextContent(openAiContentPart.Text)); + break; + + case ChatMessageContentPartKind.Image when (openAiContentPart.ImageBytes is { } bytes): + contents.Add(new ImageContent(bytes.ToArray(), openAiContentPart.ImageBytesMediaType)); + break; + + case ChatMessageContentPartKind.Image: + contents.Add(new ImageContent(openAiContentPart.ImageUri?.ToString() ?? string.Empty)); + break; + + } + } + + return contents; + } + + /// Converts a list of to a list of . 
+ private static List ToOpenAIChatContent(IList contents) + { + List parts = []; + foreach (var content in contents) + { + switch (content) + { + case TextContent textContent: + parts.Add(ChatMessageContentPart.CreateTextPart(textContent.Text)); + break; + + case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data: + parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType)); + break; + + case ImageContent imageContent when imageContent.Uri is string uri: + parts.Add(ChatMessageContentPart.CreateImagePart(new Uri(uri))); + break; + } + } + + if (parts.Count == 0) + { + parts.Add(ChatMessageContentPart.CreateTextPart(string.Empty)); + } + + return parts; + } + +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + private static readonly Func> _getMessagesAccessor = + (Func>) + typeof(ChatCompletionOptions).GetMethod("get_Messages", BindingFlags.NonPublic | BindingFlags.Instance)! + .CreateDelegate(typeof(Func>))!; + + private static readonly Func _getStreamAccessor = + (Func) + typeof(ChatCompletionOptions).GetMethod("get_Stream", BindingFlags.NonPublic | BindingFlags.Instance)! + .CreateDelegate(typeof(Func))!; + + private static readonly MethodInfo _getModelIdAccessor = + typeof(ChatCompletionOptions).GetMethod("get_Model", BindingFlags.NonPublic | BindingFlags.Instance)!; +#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs new file mode 100644 index 00000000000..9fe6fa3fada --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMappers.StreamingChatCompletion.cs @@ -0,0 +1,205 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable CA1859 // Use concrete types when possible for improved performance + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using OpenAI.Chat; + +namespace Microsoft.Extensions.AI; + +internal static partial class OpenAIModelMappers +{ + public static async IAsyncEnumerable ToOpenAIStreamingChatCompletionAsync( + IAsyncEnumerable chatCompletions, + JsonSerializerOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var chatCompletionUpdate in chatCompletions.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + List? toolCallUpdates = null; + ChatTokenUsage? chatTokenUsage = null; + + foreach (var content in chatCompletionUpdate.Contents) + { + if (content is FunctionCallContent functionCallContent) + { + toolCallUpdates ??= []; + toolCallUpdates.Add(OpenAIChatModelFactory.StreamingChatToolCallUpdate( + index: toolCallUpdates.Count, + toolCallId: functionCallContent.CallId, + functionName: functionCallContent.Name, + functionArgumentsUpdate: new(JsonSerializer.SerializeToUtf8Bytes(functionCallContent.Arguments, options.GetTypeInfo(typeof(IDictionary)))))); + } + else if (content is UsageContent usageContent) + { + chatTokenUsage = ToOpenAIUsage(usageContent.Details); + } + } + + yield return OpenAIChatModelFactory.StreamingChatCompletionUpdate( + completionId: chatCompletionUpdate.CompletionId, + model: chatCompletionUpdate.ModelId, + createdAt: chatCompletionUpdate.CreatedAt ?? 
default, + role: ToOpenAIChatRole(chatCompletionUpdate.Role), + finishReason: ToOpenAIFinishReason(chatCompletionUpdate.FinishReason), + contentUpdate: [.. ToOpenAIChatContent(chatCompletionUpdate.Contents)], + toolCallUpdates: toolCallUpdates, + refusalUpdate: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.RefusalUpdate)), + contentTokenLogProbabilities: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.ContentTokenLogProbabilities)), + refusalTokenLogProbabilities: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.RefusalTokenLogProbabilities)), + systemFingerprint: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.SystemFingerprint)), + usage: chatTokenUsage); + } + } + + public static async IAsyncEnumerable FromOpenAIStreamingChatCompletionAsync( + IAsyncEnumerable chatCompletionUpdates, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Dictionary? functionCallInfos = null; + ChatRole? streamedRole = null; + ChatFinishReason? finishReason = null; + StringBuilder? refusal = null; + string? completionId = null; + DateTimeOffset? createdAt = null; + string? modelId = null; + string? fingerprint = null; + + // Process each update as it arrives + await foreach (OpenAI.Chat.StreamingChatCompletionUpdate chatCompletionUpdate in chatCompletionUpdates.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= chatCompletionUpdate.Role is ChatMessageRole role ? FromOpenAIChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is OpenAI.Chat.ChatFinishReason reason ? 
FromOpenAIFinishReason(reason) : null; + completionId ??= chatCompletionUpdate.CompletionId; + createdAt ??= chatCompletionUpdate.CreatedAt; + modelId ??= chatCompletionUpdate.Model; + fingerprint ??= chatCompletionUpdate.SystemFingerprint; + + // Create the response content object. + StreamingChatCompletionUpdate completionUpdate = new() + { + CompletionId = chatCompletionUpdate.CompletionId, + CreatedAt = chatCompletionUpdate.CreatedAt, + FinishReason = finishReason, + ModelId = modelId, + RawRepresentation = chatCompletionUpdate, + Role = streamedRole, + }; + + // Populate it with any additional metadata from the OpenAI object. + if (chatCompletionUpdate.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } + + if (chatCompletionUpdate.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } + + if (fingerprint is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.SystemFingerprint)] = fingerprint; + } + + // Transfer over content update items. + if (chatCompletionUpdate.ContentUpdate is { Count: > 0 }) + { + foreach (ChatMessageContentPart contentPart in chatCompletionUpdate.ContentUpdate) + { + if (ToAIContent(contentPart) is AIContent aiContent) + { + completionUpdate.Contents.Add(aiContent); + } + } + } + + // Transfer over refusal updates. + if (chatCompletionUpdate.RefusalUpdate is not null) + { + _ = (refusal ??= new()).Append(chatCompletionUpdate.RefusalUpdate); + } + + // Transfer over tool call updates. 
+ if (chatCompletionUpdate.ToolCallUpdates is { Count: > 0 } toolCallUpdates) + { + foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) + { + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) + { + functionCallInfos[toolCallUpdate.Index] = existing = new(); + } + + existing.CallId ??= toolCallUpdate.ToolCallId; + existing.Name ??= toolCallUpdate.FunctionName; + if (toolCallUpdate.FunctionArgumentsUpdate is { } update && !update.ToMemory().IsEmpty) + { + _ = (existing.Arguments ??= new()).Append(update.ToString()); + } + } + } + + // Transfer over usage updates. + if (chatCompletionUpdate.Usage is ChatTokenUsage tokenUsage) + { + var usageDetails = FromOpenAIUsage(tokenUsage); + completionUpdate.Contents.Add(new UsageContent(usageDetails)); + } + + // Now yield the item. + yield return completionUpdate; + } + + // Now that we've received all updates, combine any for function calls into a single item to yield. + if (functionCallInfos is not null) + { + StreamingChatCompletionUpdate completionUpdate = new() + { + CompletionId = completionId, + CreatedAt = createdAt, + FinishReason = finishReason, + ModelId = modelId, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) + { + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + var callContent = ParseCallContentFromJsonString( + fci.Arguments?.ToString() ?? string.Empty, + fci.CallId!, + fci.Name!); + completionUpdate.Contents.Add(callContent); + } + } + + // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, + // add it to this function calling item. + if (refusal is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); + } + + // Propagate additional relevant metadata. 
+ if (fingerprint is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = fingerprint; + } + + yield return completionUpdate; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs new file mode 100644 index 00000000000..899a69630b8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Net.ServerSentEvents; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; +using OpenAI.Chat; + +namespace Microsoft.Extensions.AI; + +/// +/// Defines a set of helpers used to serialize Microsoft.Extensions.AI content using the OpenAI wire format. +/// +public static class OpenAISerializationHelpers +{ + /// + /// Deserializes a chat completion request in the OpenAI wire format into a pair of and values. + /// + /// The stream containing a message using the OpenAI wire format. + /// A token used to cancel the operation. + /// The deserialized list of chat messages and chat options. + public static async Task DeserializeChatCompletionRequestAsync( + Stream stream, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(stream); + + BinaryData binaryData = await BinaryData.FromStreamAsync(stream, cancellationToken).ConfigureAwait(false); + ChatCompletionOptions openAiChatOptions = JsonModelHelpers.Deserialize(binaryData); + return OpenAIModelMappers.FromOpenAIChatCompletionRequest(openAiChatOptions); + } + + /// + /// Serializes a Microsoft.Extensions.AI completion using the OpenAI wire format. 
+ /// + /// The stream to write the value. + /// The chat completion to serialize. + /// The governing function call content serialization. + /// A token used to cancel the serialization operation. + /// A task tracking the serialization operation. + public static async Task SerializeAsync( + Stream stream, + ChatCompletion chatCompletion, + JsonSerializerOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(stream); + _ = Throw.IfNull(chatCompletion); + options ??= AIJsonUtilities.DefaultOptions; + + OpenAI.Chat.ChatCompletion openAiChatCompletion = OpenAIModelMappers.ToOpenAIChatCompletion(chatCompletion, options); + BinaryData binaryData = JsonModelHelpers.Serialize(openAiChatCompletion); + await stream.WriteAsync(binaryData.ToMemory(), cancellationToken).ConfigureAwait(false); + } + + /// + /// Serializes a Microsoft.Extensions.AI streaming completion using the OpenAI wire format. + /// + /// The stream to write the value. + /// The streaming chat completions to serialize. + /// The governing function call content serialization. + /// A token used to cancel the serialization operation. + /// A task tracking the serialization operation. + public static Task SerializeStreamingAsync( + Stream stream, + IAsyncEnumerable streamingChatCompletionUpdates, + JsonSerializerOptions? 
options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(stream); + _ = Throw.IfNull(streamingChatCompletionUpdates); + options ??= AIJsonUtilities.DefaultOptions; + + var mappedUpdates = OpenAIModelMappers.ToOpenAIStreamingChatCompletionAsync(streamingChatCompletionUpdates, options, cancellationToken); + return SseFormatter.WriteAsync(ToSseEventsAsync(mappedUpdates), stream, FormatAsSseEvent, cancellationToken); + + static async IAsyncEnumerable> ToSseEventsAsync(IAsyncEnumerable updates) + { + await foreach (var update in updates.ConfigureAwait(false)) + { + BinaryData binaryData = JsonModelHelpers.Serialize(update); + yield return new(binaryData); + } + + yield return new(_finalSseEvent); + } + + static void FormatAsSseEvent(SseItem sseItem, IBufferWriter writer) => + writer.Write(sseItem.Data.ToMemory().Span); + } + + private static readonly BinaryData _finalSseEvent = new("[DONE]"u8.ToArray()); +} diff --git a/src/Shared/ServerSentEvents/ArrayBuffer.cs b/src/Shared/ServerSentEvents/ArrayBuffer.cs new file mode 100644 index 00000000000..70331991ce7 --- /dev/null +++ b/src/Shared/ServerSentEvents/ArrayBuffer.cs @@ -0,0 +1,197 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#pragma warning disable SA1405 // Debug.Assert should provide message text +#pragma warning disable IDE0032 // Use auto property +#pragma warning disable S3358 // Ternary operators should not be nested +#pragma warning disable SA1202 // Elements should be ordered by access +#pragma warning disable S109 // Magic numbers should not be used + +namespace System.Net +{ + // Warning: Mutable struct! + // The purpose of this struct is to simplify buffer management. 
+ // It manages a sliding buffer where bytes can be added at the end and removed at the beginning. + // [ActiveSpan/Memory] contains the current buffer contents; these bytes will be preserved + // (copied, if necessary) on any call to EnsureAvailableBytes. + // [AvailableSpan/Memory] contains the available bytes past the end of the current content, + // and can be written to in order to add data to the end of the buffer. + // Commit(byteCount) will extend the ActiveSpan by [byteCount] bytes into the AvailableSpan. + // Discard(byteCount) will discard [byteCount] bytes as the beginning of the ActiveSpan. + + [StructLayout(LayoutKind.Auto)] + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal struct ArrayBuffer : IDisposable + { + private readonly bool _usePool; + private byte[] _bytes; + private int _activeStart; + private int _availableStart; + + // Invariants: + // 0 <= _activeStart <= _availableStart <= bytes.Length + + public ArrayBuffer(int initialSize, bool usePool = false) + { + Debug.Assert(initialSize > 0 || usePool); + + _usePool = usePool; + _bytes = initialSize == 0 + ? Array.Empty() + : usePool ? ArrayPool.Shared.Rent(initialSize) : new byte[initialSize]; + _activeStart = 0; + _availableStart = 0; + } + + public ArrayBuffer(byte[] buffer) + { + Debug.Assert(buffer.Length > 0); + + _usePool = false; + _bytes = buffer; + _activeStart = 0; + _availableStart = 0; + } + + public void Dispose() + { + _activeStart = 0; + _availableStart = 0; + + byte[] array = _bytes; + _bytes = null!; + + if (array is not null) + { + ReturnBufferIfPooled(array); + } + } + + // This is different from Dispose as the instance remains usable afterwards (_bytes will not be null). 
+ public void ClearAndReturnBuffer() + { + Debug.Assert(_usePool); + Debug.Assert(_bytes is not null); + + _activeStart = 0; + _availableStart = 0; + + byte[] bufferToReturn = _bytes!; + _bytes = Array.Empty(); + ReturnBufferIfPooled(bufferToReturn); + } + + public readonly int ActiveLength => _availableStart - _activeStart; + public readonly Span ActiveSpan => new Span(_bytes, _activeStart, _availableStart - _activeStart); + public readonly ReadOnlySpan ActiveReadOnlySpan => new ReadOnlySpan(_bytes, _activeStart, _availableStart - _activeStart); + public readonly Memory ActiveMemory => new Memory(_bytes, _activeStart, _availableStart - _activeStart); + + public readonly int AvailableLength => _bytes.Length - _availableStart; + public readonly Span AvailableSpan => _bytes.AsSpan(_availableStart); + public readonly Memory AvailableMemory => _bytes.AsMemory(_availableStart); + public readonly Memory AvailableMemorySliced(int length) => new Memory(_bytes, _availableStart, length); + + public readonly int Capacity => _bytes.Length; + public readonly int ActiveStartOffset => _activeStart; + + public readonly byte[] DangerousGetUnderlyingBuffer() => _bytes; + + public void Discard(int byteCount) + { + Debug.Assert(byteCount <= ActiveLength, $"Expected {byteCount} <= {ActiveLength}"); + _activeStart += byteCount; + + if (_activeStart == _availableStart) + { + _activeStart = 0; + _availableStart = 0; + } + } + + public void Commit(int byteCount) + { + Debug.Assert(byteCount <= AvailableLength); + _availableStart += byteCount; + } + + // Ensure at least [byteCount] bytes to write to. 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void EnsureAvailableSpace(int byteCount) + { + if (byteCount > AvailableLength) + { + EnsureAvailableSpaceCore(byteCount); + } + } + + private void EnsureAvailableSpaceCore(int byteCount) + { + Debug.Assert(AvailableLength < byteCount); + + if (_bytes.Length == 0) + { + Debug.Assert(_usePool && _activeStart == 0 && _availableStart == 0); + _bytes = ArrayPool.Shared.Rent(byteCount); + return; + } + + int totalFree = _activeStart + AvailableLength; + if (byteCount <= totalFree) + { + // We can free up enough space by just shifting the bytes down, so do so. + Buffer.BlockCopy(_bytes, _activeStart, _bytes, 0, ActiveLength); + _availableStart = ActiveLength; + _activeStart = 0; + Debug.Assert(byteCount <= AvailableLength); + return; + } + + // Double the size of the buffer until we have enough space. + int desiredSize = ActiveLength + byteCount; + int newSize = _bytes.Length; + do + { + newSize *= 2; + } + while (newSize < desiredSize); + + byte[] newBytes = _usePool ? 
+ ArrayPool.Shared.Rent(newSize) : + new byte[newSize]; + byte[] oldBytes = _bytes; + + if (ActiveLength != 0) + { + Buffer.BlockCopy(oldBytes, _activeStart, newBytes, 0, ActiveLength); + } + + _availableStart = ActiveLength; + _activeStart = 0; + + _bytes = newBytes; + ReturnBufferIfPooled(oldBytes); + + Debug.Assert(byteCount <= AvailableLength); + } + + public void Grow() + { + EnsureAvailableSpaceCore(AvailableLength + 1); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private readonly void ReturnBufferIfPooled(byte[] buffer) + { + // The buffer may be Array.Empty() + if (_usePool && buffer.Length > 0) + { + ArrayPool.Shared.Return(buffer); + } + } + } +} diff --git a/src/Shared/ServerSentEvents/Helpers.cs b/src/Shared/ServerSentEvents/Helpers.cs new file mode 100644 index 00000000000..c976162fa1e --- /dev/null +++ b/src/Shared/ServerSentEvents/Helpers.cs @@ -0,0 +1,127 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Buffers; +using System.ComponentModel; +using System.Diagnostics; +using System.Globalization; +#if !NET +using System.IO; +using System.Runtime.InteropServices; +#endif +using System.Text; +#if !NET +using System.Threading; +using System.Threading.Tasks; +#endif + +#pragma warning disable SA1405 // Debug.Assert should provide message text +#pragma warning disable S2333 // Redundant modifiers should not be used +#pragma warning disable SA1519 // Braces should not be omitted from multi-line child statement +#pragma warning disable LA0001 // Use Microsoft.Shared.Diagnostics.Throws for improved performance +#pragma warning disable LA0002 // Use Microsoft.Shared.Diagnostics.ToInvariantString for improved performance + +namespace System.Net.ServerSentEvents +{ + [EditorBrowsable(EditorBrowsableState.Never)] + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal static class Helpers + { + public static void WriteUtf8Number(this IBufferWriter writer, long value) + { +#if NET + const int MaxDecimalDigits = 20; + Span buffer = writer.GetSpan(MaxDecimalDigits); + Debug.Assert(buffer.Length >= MaxDecimalDigits); + + bool success = value.TryFormat(buffer, out int bytesWritten, provider: CultureInfo.InvariantCulture); + Debug.Assert(success); + writer.Advance(bytesWritten); +#else + writer.WriteUtf8String(value.ToString(CultureInfo.InvariantCulture).AsSpan()); +#endif + } + + public static void WriteUtf8String(this IBufferWriter writer, ReadOnlySpan value) + { + if (value.IsEmpty) + { + return; + } + + Span buffer = writer.GetSpan(value.Length); + Debug.Assert(value.Length <= buffer.Length); + value.CopyTo(buffer); + writer.Advance(value.Length); + } + + public static unsafe void WriteUtf8String(this IBufferWriter writer, ReadOnlySpan value) + { + if (value.IsEmpty) + { + return; + } + + int maxByteCount = Encoding.UTF8.GetMaxByteCount(value.Length); + Span buffer = writer.GetSpan(maxByteCount); + Debug.Assert(maxByteCount <= buffer.Length); + 
int bytesWritten; +#if NET + bytesWritten = Encoding.UTF8.GetBytes(value, buffer); +#else + fixed (char* chars = value) + fixed (byte* bytes = buffer) + { + bytesWritten = Encoding.UTF8.GetBytes(chars, value.Length, bytes, maxByteCount); + } +#endif + writer.Advance(bytesWritten); + } + + public static bool ContainsLineBreaks(this ReadOnlySpan text) => + text.IndexOfAny('\r', '\n') >= 0; + +#if !NET + + public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)) + { + return new ValueTask(stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken)); + } + else + { + return WriteAsyncUsingPooledBuffer(stream, buffer, cancellationToken); + + static async ValueTask WriteAsyncUsingPooledBuffer(Stream stream, ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + byte[] sharedBuffer = ArrayPool.Shared.Rent(buffer.Length); + buffer.Span.CopyTo(sharedBuffer); + try + { + await stream.WriteAsync(sharedBuffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(sharedBuffer); + } + } + } + } +#endif + + public static unsafe string Utf8GetString(ReadOnlySpan bytes) + { +#if NET + return Encoding.UTF8.GetString(bytes); +#else + fixed (byte* ptr = bytes) + { + return ptr is null ? + string.Empty : + Encoding.UTF8.GetString(ptr, bytes.Length); + } +#endif + } + } +} diff --git a/src/Shared/ServerSentEvents/PooledByteBufferWriter.cs b/src/Shared/ServerSentEvents/PooledByteBufferWriter.cs new file mode 100644 index 00000000000..0c03d4fe91a --- /dev/null +++ b/src/Shared/ServerSentEvents/PooledByteBufferWriter.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Buffers; +using System.ComponentModel; + +namespace System.Net.ServerSentEvents +{ + [EditorBrowsable(EditorBrowsableState.Never)] + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal sealed class PooledByteBufferWriter : IBufferWriter, IDisposable + { + private const int MinimumBufferSize = 256; + private ArrayBuffer _buffer = new(initialSize: 256, usePool: true); + + public void Advance(int count) => _buffer.Commit(count); + + public Memory GetMemory(int sizeHint = 0) + { + _buffer.EnsureAvailableSpace(Math.Max(sizeHint, MinimumBufferSize)); + return _buffer.AvailableMemory; + } + + public Span GetSpan(int sizeHint = 0) + { + _buffer.EnsureAvailableSpace(Math.Max(sizeHint, MinimumBufferSize)); + return _buffer.AvailableSpan; + } + + public ReadOnlyMemory WrittenMemory => _buffer.ActiveMemory; + public int Capacity => _buffer.Capacity; + public int WrittenCount => _buffer.ActiveLength; + public void Reset() => _buffer.Discard(_buffer.ActiveLength); + public void Dispose() => _buffer.Dispose(); + } +} diff --git a/src/Shared/ServerSentEvents/README.md b/src/Shared/ServerSentEvents/README.md new file mode 100644 index 00000000000..5afb7ad6627 --- /dev/null +++ b/src/Shared/ServerSentEvents/README.md @@ -0,0 +1,11 @@ +# System.Net.ServerSentEvents + +Polyfill for the System.Net.ServerSentEvents library, including the `SseFormatter` component available in .NET 10. + +To use this in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/src/Shared/ServerSentEvents/SseFormatter.cs b/src/Shared/ServerSentEvents/SseFormatter.cs new file mode 100644 index 00000000000..929d7672ec9 --- /dev/null +++ b/src/Shared/ServerSentEvents/SseFormatter.cs @@ -0,0 +1,169 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1405 // Debug.Assert should provide message text + +namespace System.Net.ServerSentEvents +{ + /// + /// Provides methods for formatting server-sent events. + /// + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal static class SseFormatter + { + private static readonly byte[] _newLine = "\n"u8.ToArray(); + + /// + /// Writes the of server-sent events to the stream. + /// + /// The events to write to the stream. + /// The destination stream to write the events. + /// The that can be used to cancel the write operation. + /// A task that represents the asynchronous write operation. + public static Task WriteAsync(IAsyncEnumerable> source, Stream destination, CancellationToken cancellationToken = default) + { + if (source is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(source)); + } + + if (destination is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(destination)); + } + + return WriteAsyncCore(source, destination, static (item, writer) => writer.WriteUtf8String(item.Data.AsSpan()), cancellationToken); + } + + /// + /// Writes the of server-sent events to the stream. + /// + /// The data type of the event. + /// The events to write to the stream. + /// The destination stream to write the events. + /// The formatter for the data field of given event. + /// The that can be used to cancel the write operation. + /// A task that represents the asynchronous write operation. 
+ public static Task WriteAsync(IAsyncEnumerable> source, Stream destination, Action, IBufferWriter> itemFormatter, CancellationToken cancellationToken = default) + { + if (source is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(source)); + } + + if (destination is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(destination)); + } + + if (itemFormatter is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(itemFormatter)); + } + + return WriteAsyncCore(source, destination, itemFormatter, cancellationToken); + } + + private static async Task WriteAsyncCore(IAsyncEnumerable> source, Stream destination, Action, IBufferWriter> itemFormatter, CancellationToken cancellationToken) + { + using PooledByteBufferWriter bufferWriter = new(); + using PooledByteBufferWriter userDataBufferWriter = new(); + + await foreach (SseItem item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + itemFormatter(item, userDataBufferWriter); + + FormatSseEvent( + bufferWriter, + eventType: item._eventType, // Do not use the public property since it normalizes to "message" if null + data: userDataBufferWriter.WrittenMemory.Span, + eventId: item.EventId, + reconnectionInterval: item.ReconnectionInterval); + + await destination.WriteAsync(bufferWriter.WrittenMemory, cancellationToken).ConfigureAwait(false); + + userDataBufferWriter.Reset(); + bufferWriter.Reset(); + } + } + + private static void FormatSseEvent( + PooledByteBufferWriter bufferWriter, + string? eventType, + ReadOnlySpan data, + string? eventId, + TimeSpan? 
reconnectionInterval) + { + Debug.Assert(bufferWriter.WrittenCount is 0); + + if (eventType is not null) + { + Debug.Assert(!eventType.AsSpan().ContainsLineBreaks()); + + bufferWriter.WriteUtf8String("event: "u8); + bufferWriter.WriteUtf8String(eventType.AsSpan()); + bufferWriter.WriteUtf8String(_newLine); + } + + WriteLinesWithPrefix(bufferWriter, prefix: "data: "u8, data); + bufferWriter.Write(_newLine); + + if (eventId is not null) + { + Debug.Assert(!eventId.AsSpan().ContainsLineBreaks()); + + bufferWriter.WriteUtf8String("id: "u8); + bufferWriter.WriteUtf8String(eventId.AsSpan()); + bufferWriter.WriteUtf8String(_newLine); + } + + if (reconnectionInterval is { } retry) + { + Debug.Assert(retry >= TimeSpan.Zero); + + bufferWriter.WriteUtf8String("retry: "u8); + bufferWriter.WriteUtf8Number((long)retry.TotalMilliseconds); + bufferWriter.WriteUtf8String(_newLine); + } + + bufferWriter.WriteUtf8String(_newLine); + } + + private static void WriteLinesWithPrefix(PooledByteBufferWriter writer, ReadOnlySpan prefix, ReadOnlySpan data) + { + // Writes a potentially multi-line string, prefixing each line with the given prefix. + // Both \n and \r\n sequences are normalized to \n. + + while (true) + { + writer.WriteUtf8String(prefix); + + int i = data.IndexOfAny((byte)'\r', (byte)'\n'); + if (i < 0) + { + writer.WriteUtf8String(data); + return; + } + + int lineLength = i; + if (data[i++] == '\r' && i < data.Length && data[i] == '\n') + { + i++; + } + + ReadOnlySpan nextLine = data.Slice(0, lineLength); + data = data.Slice(i); + + writer.WriteUtf8String(nextLine); + writer.WriteUtf8String(_newLine); + } + } + } +} diff --git a/src/Shared/ServerSentEvents/SseItem.cs b/src/Shared/ServerSentEvents/SseItem.cs new file mode 100644 index 00000000000..9c6092fd3cf --- /dev/null +++ b/src/Shared/ServerSentEvents/SseItem.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +#pragma warning disable CA1815 // Override equals and operator equals on value types +#pragma warning disable IDE1006 // Naming Styles + +using System.ComponentModel; + +namespace System.Net.ServerSentEvents +{ + /// Represents a server-sent event. + /// Specifies the type of data payload in the event. + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal readonly struct SseItem + { + /// The event's type. + [EditorBrowsable(EditorBrowsableState.Never)] + internal readonly string? _eventType; + + /// The event's id. + private readonly string? _eventId; + + /// The event's reconnection interval. + private readonly TimeSpan? _reconnectionInterval; + + /// Initializes a new instance of the struct. + /// The event's payload. + /// The event's type. + /// Thrown when contains a line break. + public SseItem(T data, string? eventType = null) + { + if (eventType.AsSpan().ContainsLineBreaks() is true) + { + ThrowHelper.ThrowArgumentException_CannotContainLineBreaks(nameof(eventType)); + } + + Data = data; + _eventType = eventType; + } + + /// Gets the event's payload. + public T Data { get; } + + /// Gets the event's type. + public string EventType => _eventType ?? SseParser.EventTypeDefault; + + /// Gets the event's id. + /// Thrown when the value contains a line break. + public string? EventId + { + get => _eventId; + init + { + if (value.AsSpan().ContainsLineBreaks() is true) + { + ThrowHelper.ThrowArgumentException_CannotContainLineBreaks(nameof(EventId)); + } + + _eventId = value; + } + } + + /// Gets the event's retry interval. + /// + /// When specified on an event, instructs the client to update its reconnection time to the specified value. + /// + public TimeSpan? 
ReconnectionInterval + { + get => _reconnectionInterval; + init + { + if (value < TimeSpan.Zero) + { + ThrowHelper.ThrowArgumentException_CannotBeNegative(nameof(ReconnectionInterval)); + } + + _reconnectionInterval = value; + } + } + } +} diff --git a/src/Shared/ServerSentEvents/SseItemParser.cs b/src/Shared/ServerSentEvents/SseItemParser.cs new file mode 100644 index 00000000000..62a2bf475c3 --- /dev/null +++ b/src/Shared/ServerSentEvents/SseItemParser.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.ServerSentEvents +{ + /// Encapsulates a method for parsing the bytes payload of a server-sent event. + /// Specifies the type of the return value of the parser. + /// The event's type. + /// The event's payload bytes. + /// The parsed . + internal delegate T SseItemParser(string eventType, ReadOnlySpan data); +} diff --git a/src/Shared/ServerSentEvents/SseParser.cs b/src/Shared/ServerSentEvents/SseParser.cs new file mode 100644 index 00000000000..fbea0995eed --- /dev/null +++ b/src/Shared/ServerSentEvents/SseParser.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.Text; + +#pragma warning disable S2333 // Redundant modifiers should not be used + +namespace System.Net.ServerSentEvents +{ + /// Provides a parser for parsing server-sent events. + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal static class SseParser + { + /// The default ("message") for an event that did not explicitly specify a type. + public const string EventTypeDefault = "message"; + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// The stream containing the data to parse. 
+ /// + /// The enumerable of strings, which can be enumerated synchronously or asynchronously. The strings + /// are decoded from the UTF8-encoded bytes of the payload of each event. + /// + /// is null. + /// + /// This overload has behavior equivalent to calling with a delegate + /// that decodes the data of each event using 's GetString method. + /// + public static SseParser Create(Stream sseStream) => + Create(sseStream, static (_, bytes) => Helpers.Utf8GetString(bytes)); + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// Specifies the type of data in each event. + /// The stream containing the data to parse. + /// The parser to use to transform each payload of bytes into a data element. + /// The enumerable, which can be enumerated synchronously or asynchronously. + /// or is null. + public static SseParser Create(Stream sseStream, SseItemParser itemParser) + { + if (sseStream is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(sseStream)); + } + + if (itemParser is null) + { + ThrowHelper.ThrowArgumentNullException(nameof(itemParser)); + } + + return new SseParser(sseStream, itemParser); + } + } +} diff --git a/src/Shared/ServerSentEvents/SseParser_1.cs b/src/Shared/ServerSentEvents/SseParser_1.cs new file mode 100644 index 00000000000..579f01d2027 --- /dev/null +++ b/src/Shared/ServerSentEvents/SseParser_1.cs @@ -0,0 +1,569 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +using System.Buffers; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1649 // File name should match first type name +#pragma warning disable IDE1006 // Naming Styles +#pragma warning disable SA1310 // Field names should not contain underscore +#pragma warning disable SA1203 // Constants should appear before fields +#pragma warning disable SA1514 // Element documentation header should be preceded by blank line +#pragma warning disable SA1623 // Property summary documentation should match accessors +#pragma warning disable IDE0011 // Add braces +#pragma warning disable SA1114 // Parameter list should follow declaration +#pragma warning disable SA1106 // Code should not contain empty statements +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable SA1108 // Block statements should not contain embedded comments +#pragma warning disable format + +namespace System.Net.ServerSentEvents +{ + /// Provides a parser for server-sent events information. + /// Specifies the type of data parsed from an event. + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + internal sealed class SseParser + { + // For reference: + // Specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events + + /// Carriage Return. + private const byte CR = (byte)'\r'; + /// Line Feed. + private const byte LF = (byte)'\n'; + /// Carriage Return Line Feed. + private static ReadOnlySpan CRLF => "\r\n"u8; + + /// The maximum number of milliseconds representible by . 
+ private readonly long TimeSpan_MaxValueMilliseconds = (long)TimeSpan.MaxValue.TotalMilliseconds; + + /// The default size of an ArrayPool buffer to rent. + /// Larger size used by default to minimize number of reads. Smaller size used in debug to stress growth/shifting logic. + private const int DefaultArrayPoolRentSize = +#if DEBUG + 16; +#else + 1024; +#endif + + /// The stream to be parsed. + private readonly Stream _stream; + + /// The parser delegate used to transform bytes into a . + private readonly SseItemParser _itemParser; + + /// Indicates whether the enumerable has already been used for enumeration. + private int _used; + + /// Buffer, either empty or rented, containing the data being read from the stream while looking for the next line. + private byte[] _lineBuffer = []; + /// The starting offset of valid data in . + private int _lineOffset; + /// The length of valid data in , starting from . + private int _lineLength; + /// The index in where a newline ('\r', '\n', or "\r\n") was found. + private int _newlineIndex; + /// The index in of characters already checked for newlines. + /// + /// This is to avoid O(LineLength^2) behavior in the rare case where we have long lines that are built-up over multiple reads. + /// We want to avoid re-checking the same characters we've already checked over and over again. + /// + private int _lastSearchedForNewline; + /// Set when eof has been reached in the stream. + private bool _eof; + + /// Rented buffer containing buffered data for the next event. + private byte[]? _dataBuffer; + /// The length of valid data in , starting from index 0. + private int _dataLength; + /// Whether data has been appended to . + /// This can be different than != 0 if empty data was appended. + private bool _dataAppended; + + /// The event type for the next event. + private string? _eventType; + + /// The event id for the next event. + private string? _eventId; + + /// The reconnection interval for the next event. + private TimeSpan? 
_nextReconnectionInterval; + + /// Initializes a new instance of the class. + /// The stream to parse. + /// The function to use to parse payload bytes into a . + internal SseParser(Stream stream, SseItemParser itemParser) + { + _stream = stream; + _itemParser = itemParser; + } + + /// Gets an enumerable of the server-sent events from this parser. + /// The parser has already been enumerated. Such an exception may propagate out of a call to . + public IEnumerable> Enumerate() + { + // Validate that the parser is only used for one enumeration. + ThrowIfNotFirstEnumeration(); + + // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream, + // so we want it to be large enough to reduce the number of reads we need to do when data is + // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.) + _lineBuffer = ArrayPool.Shared.Rent(DefaultArrayPoolRentSize); + try + { + // Spec: "Event streams in this format must always be encoded as UTF-8". + // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.) + while (FillLineBuffer() != 0 && _lineLength < Utf8Bom.Length); + SkipBomIfPresent(); + + // Process all events in the stream. + while (true) + { + // See if there's a complete line in data already read from the stream. Lines are permitted to + // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However, + // if we only find a CR and it's at the end of the read data, don't process it now, as we want + // to process it together with an LF that might immediately follow, rather than treating them + // as two separate characters, in which case we'd incorrectly process the CR as a line by itself. 
+ GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength); + _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF); + if (_newlineIndex >= 0) + { + _lastSearchedForNewline = -1; + _newlineIndex += searchOffset; + if (_lineBuffer[_newlineIndex] is LF || // the newline is LF + _newlineIndex - _lineOffset + 1 < _lineLength || // we must have CR and we have whatever comes after it + _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF + { + // Process the line. + if (ProcessLine(out SseItem sseItem, out int advance)) + { + yield return sseItem; + } + + // Move past the line. + _lineOffset += advance; + _lineLength -= advance; + continue; + } + } + else + { + // Record the last position searched for a newline. The next time we search, + // we'll search from here rather than from _lineOffset, in order to avoid searching + // the same characters again. + _lastSearchedForNewline = _lineOffset + _lineLength; + } + + // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done. + if (_eof) + { + // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an + // event, before the final empty line, the incomplete event is not dispatched.)" + break; + } + + // Read more data into the buffer. + _ = FillLineBuffer(); + } + } + finally + { + ArrayPool.Shared.Return(_lineBuffer); + if (_dataBuffer is not null) + { + ArrayPool.Shared.Return(_dataBuffer); + } + } + } + + /// Gets an asynchronous enumerable of the server-sent events from this parser. + /// The cancellation token to use to cancel the enumeration. + /// The parser has already been enumerated. May propagate out of a call to . + /// The enumeration was canceled. May propagate out of a call to . 
        public async IAsyncEnumerable<SseItem<T>> EnumerateAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Validate that the parser is only used for one enumeration.
            ThrowIfNotFirstEnumeration();

            // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream,
            // so we want it to be large enough to reduce the number of reads we need to do when data is
            // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.)
            _lineBuffer = ArrayPool<byte>.Shared.Rent(DefaultArrayPoolRentSize);
            try
            {
                // Spec: "Event streams in this format must always be encoded as UTF-8".
                // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.)
                // Keep filling until we have at least BOM-length bytes buffered, or the stream ends.
                while (await FillLineBufferAsync(cancellationToken).ConfigureAwait(false) != 0 && _lineLength < Utf8Bom.Length) ;
                SkipBomIfPresent();

                // Process all events in the stream.
                while (true)
                {
                    // See if there's a complete line in data already read from the stream. Lines are permitted to
                    // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However,
                    // if we only find a CR and it's at the end of the read data, don't process it now, as we want
                    // to process it together with an LF that might immediately follow, rather than treating them
                    // as two separate characters, in which case we'd incorrectly process the CR as a line by itself.
                    GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength);
                    _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF);
                    if (_newlineIndex >= 0)
                    {
                        _lastSearchedForNewline = -1;
                        _newlineIndex += searchOffset;
                        if (_lineBuffer[_newlineIndex] is LF || // newline is LF
                            _newlineIndex - _lineOffset + 1 < _lineLength || // newline is CR, and we have whatever comes after it
                            _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF
                        {
                            // Process the line.
                            if (ProcessLine(out SseItem<T> sseItem, out int advance))
                            {
                                yield return sseItem;
                            }

                            // Move past the line.
                            _lineOffset += advance;
                            _lineLength -= advance;
                            continue;
                        }
                    }
                    else
                    {
                        // Record the last position searched for a newline. The next time we search,
                        // we'll search from here rather than from _lineOffset, in order to avoid searching
                        // the same characters again.
                        _lastSearchedForNewline = searchOffset + searchLength;
                    }

                    // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done.
                    if (_eof)
                    {
                        // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an
                        // event, before the final empty line, the incomplete event is not dispatched.)"
                        break;
                    }

                    // Read more data into the buffer.
                    _ = await FillLineBufferAsync(cancellationToken).ConfigureAwait(false);
                }
            }
            finally
            {
                // Return the rented buffers; _dataBuffer is only rented lazily when a multi-line
                // "data" payload had to be accumulated.
                ArrayPool<byte>.Shared.Return(_lineBuffer);
                if (_dataBuffer is not null)
                {
                    ArrayPool<byte>.Shared.Return(_dataBuffer);
                }
            }
        }

        /// <summary>Gets the next index and length with which to perform a newline search.</summary>
        /// <param name="searchOffset">The index in the line buffer at which to begin searching.</param>
        /// <param name="searchLength">The number of bytes to search.</param>
        private void GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength)
        {
            if (_lastSearchedForNewline > _lineOffset)
            {
                // Resume just past the bytes already scanned so repeated buffer fills
                // don't rescan the same data for a newline.
                searchOffset = _lastSearchedForNewline;
                searchLength = _lineLength - (_lastSearchedForNewline - _lineOffset);
            }
            else
            {
                searchOffset = _lineOffset;
                searchLength = _lineLength;
            }

            Debug.Assert(searchOffset >= _lineOffset, $"{searchOffset}, {_lineLength}");
            Debug.Assert(searchOffset <= _lineOffset + _lineLength, $"{searchOffset}, {_lineOffset}, {_lineLength}");
            Debug.Assert(searchOffset <= _lineBuffer.Length, $"{searchOffset}, {_lineBuffer.Length}");

            Debug.Assert(searchLength >= 0, $"{searchLength}");
            Debug.Assert(searchLength <= _lineLength, $"{searchLength}, {_lineLength}");
        }

        /// <summary>Gets the length (1 for CR or LF, 2 for CRLF) of the newline at <see cref="_newlineIndex"/>.</summary>
        private int GetNewLineLength()
        {
            Debug.Assert(_newlineIndex - _lineOffset < _lineLength, "Expected to be positioned at a non-empty newline");
            return _lineBuffer.AsSpan(_newlineIndex, _lineLength - (_newlineIndex - _lineOffset)).StartsWith(CRLF) ? 2 : 1;
        }

        /// <summary>
        /// If there's no room remaining in the line buffer, either shifts the contents
        /// left or grows the buffer in order to make room for the next read.
        /// </summary>
        private void ShiftOrGrowLineBufferIfNecessary()
        {
            // If data we've read is butting up against the end of the buffer and
            // it's not taking up the entire buffer, slide what's there down to
            // the beginning, making room to read more data into the buffer (since
            // there's no newline in the data that's there). Otherwise, if the whole
            // buffer is full, grow the buffer to accommodate more data, since, again,
            // what's there doesn't contain a newline and thus a line is longer than
            // the current buffer accommodates.
            if (_lineOffset + _lineLength == _lineBuffer.Length)
            {
                if (_lineOffset != 0)
                {
                    _lineBuffer.AsSpan(_lineOffset, _lineLength).CopyTo(_lineBuffer);
                    if (_lastSearchedForNewline >= 0)
                    {
                        // Keep the resume-search position in sync with the shifted contents.
                        _lastSearchedForNewline -= _lineOffset;
                    }

                    _lineOffset = 0;
                }
                else if (_lineLength == _lineBuffer.Length)
                {
                    GrowBuffer(ref _lineBuffer, _lineBuffer.Length * 2);
                }
            }
        }

        /// <summary>Processes a complete line from the SSE stream.</summary>
        /// <param name="sseItem">The parsed item if the method returns true.</param>
        /// <param name="advance">How many characters to advance in the line buffer.</param>
        /// <returns>true if an SSE item was successfully parsed; otherwise, false.</returns>
        private bool ProcessLine(out SseItem<T> sseItem, out int advance)
        {
            ReadOnlySpan<byte> line = _lineBuffer.AsSpan(_lineOffset, _newlineIndex - _lineOffset);

            // Spec: "If the line is empty (a blank line) Dispatch the event"
            if (line.IsEmpty)
            {
                advance = GetNewLineLength();

                if (_dataAppended)
                {
                    // A blank line terminates the event: parse the accumulated data and reset
                    // all per-event state (event type, id, retry) for the next event.
                    T data = _itemParser(_eventType ?? SseParser.EventTypeDefault, _dataBuffer.AsSpan(0, _dataLength));
                    sseItem = new SseItem<T>(data, _eventType) { EventId = _eventId, ReconnectionInterval = _nextReconnectionInterval };
                    _eventType = null;
                    _eventId = null;
                    _nextReconnectionInterval = null;
                    _dataLength = 0;
                    _dataAppended = false;
                    return true;
                }

                sseItem = default;
                return false;
            }

            // Find the colon separating the field name and value.
            int colonPos = line.IndexOf((byte)':');
            ReadOnlySpan<byte> fieldName;
            ReadOnlySpan<byte> fieldValue;
            if (colonPos >= 0)
            {
                // Spec: "Collect the characters on the line before the first U+003A COLON character (:), and let field be that string."
                fieldName = line.Slice(0, colonPos);

                // Spec: "Collect the characters on the line after the first U+003A COLON character (:), and let value be that string.
                // If value starts with a U+0020 SPACE character, remove it from value."
                fieldValue = line.Slice(colonPos + 1);
                if (!fieldValue.IsEmpty && fieldValue[0] == (byte)' ')
                {
                    fieldValue = fieldValue.Slice(1);
                }
            }
            else
            {
                // Spec: "using the whole line as the field name, and the empty string as the field value."
                fieldName = line;
                fieldValue = [];
            }

            if (fieldName.SequenceEqual("data"u8))
            {
                // Spec: "Append the field value to the data buffer, then append a single U+000A LINE FEED (LF) character to the data buffer."
                // Spec: "If the data buffer's last character is a U+000A LINE FEED (LF) character, then remove the last character from the data buffer."

                // If there's nothing currently in the data buffer and we can easily detect that this line is immediately followed by
                // an empty line, we can optimize it to just handle the data directly from the line buffer, rather than first copying
                // into the data buffer and dispatching from there.
                if (!_dataAppended)
                {
                    int newlineLength = GetNewLineLength();
                    ReadOnlySpan<byte> remainder = _lineBuffer.AsSpan(_newlineIndex + newlineLength, _lineLength - line.Length - newlineLength);
                    if (!remainder.IsEmpty &&
                        (remainder[0] is LF || (remainder[0] is CR && remainder.Length > 1)))
                    {
                        // Consume this data line plus the immediately-following blank line in one step.
                        advance = line.Length + newlineLength + (remainder.StartsWith(CRLF) ? 2 : 1);
                        T data = _itemParser(_eventType ?? SseParser.EventTypeDefault, fieldValue);
                        sseItem = new SseItem<T>(data, _eventType) { EventId = _eventId, ReconnectionInterval = _nextReconnectionInterval };
                        _eventType = null;
                        _eventId = null;
                        _nextReconnectionInterval = null;
                        return true;
                    }
                }

                // We need to copy the data from the line buffer to the data buffer. Make sure there's enough room.
                if (_dataBuffer is null || _dataLength + _lineLength + 1 > _dataBuffer.Length)
                {
                    GrowBuffer(ref _dataBuffer, _dataLength + _lineLength + 1);
                }

                // Append a newline if there's already content in the buffer.
                // Then copy the field value to the data buffer
                if (_dataAppended)
                {
                    _dataBuffer[_dataLength++] = LF;
                }

                fieldValue.CopyTo(_dataBuffer.AsSpan(_dataLength));
                _dataLength += fieldValue.Length;
                _dataAppended = true;
            }
            else if (fieldName.SequenceEqual("event"u8))
            {
                // Spec: "Set the event type buffer to field value."
                _eventType = Helpers.Utf8GetString(fieldValue);
            }
            else if (fieldName.SequenceEqual("id"u8))
            {
                // Spec: "If the field value does not contain U+0000 NULL, then set the last event ID buffer to the field value. Otherwise, ignore the field."
                if (fieldValue.IndexOf((byte)'\0') < 0)
                {
                    // Note that fieldValue might be empty, in which case LastEventId will naturally be reset to the empty string. This is per spec.
                    LastEventId = _eventId = Helpers.Utf8GetString(fieldValue);
                }
            }
            else if (fieldName.SequenceEqual("retry"u8))
            {
                // Spec: "If the field value consists of only ASCII digits, then interpret the field value as an integer in base ten,
                // and set the event stream's reconnection time to that integer. Otherwise, ignore the field."
                if (long.TryParse(
#if NET
                    fieldValue,
#else
                    Helpers.Utf8GetString(fieldValue),
#endif
                    NumberStyles.None, CultureInfo.InvariantCulture, out long milliseconds) &&
                    milliseconds >= 0 && milliseconds <= TimeSpan_MaxValueMilliseconds)
                {
                    // Workaround for TimeSpan.FromMilliseconds not being able to roundtrip TimeSpan.MaxValue
                    TimeSpan timeSpan = milliseconds == TimeSpan_MaxValueMilliseconds ? TimeSpan.MaxValue : TimeSpan.FromMilliseconds(milliseconds);
                    _nextReconnectionInterval = ReconnectionInterval = timeSpan;
                }
            }
            else
            {
                // We'll end up here if the line starts with a colon, producing an empty field name, or if the field name is otherwise unrecognized.
                // Spec: "If the line starts with a U+003A COLON character (:) Ignore the line."
                // Spec: "Otherwise, The field is ignored"
            }

            advance = line.Length + GetNewLineLength();
            sseItem = default;
            return false;
        }

        /// <summary>Gets the last event ID.</summary>
        /// <remarks>This value is updated any time a new last event ID is parsed. It is not reset between SSE items.</remarks>
        public string LastEventId { get; private set; } = string.Empty; // Spec: "must be initialized to the empty string"

        /// <summary>Gets the reconnection interval.</summary>
        /// <remarks>
        /// If no retry event was received, this defaults to <see cref="Timeout.InfiniteTimeSpan"/>, and it will only
        /// ever be <see cref="Timeout.InfiniteTimeSpan"/> in that situation. If a client wishes to retry, the server-sent
        /// events specification states that the interval may then be decided by the client implementation and should be a
        /// few seconds.
        /// </remarks>
        public TimeSpan ReconnectionInterval { get; private set; } = Timeout.InfiniteTimeSpan;

        /// <summary>Transitions the object to a used state, throwing if it's already been used.</summary>
        private void ThrowIfNotFirstEnumeration()
        {
            // Interlocked.Exchange makes concurrent first-enumeration attempts race-safe.
            if (Interlocked.Exchange(ref _used, 1) != 0)
            {
                ThrowHelper.ThrowInvalidOperationException_EnumerateOnlyOnce();
            }
        }

        /// <summary>Reads data from the stream into the line buffer.</summary>
        /// <returns>The number of bytes read; 0 once end-of-stream has been reached (also sets <see cref="_eof"/>).</returns>
        private int FillLineBuffer()
        {
            ShiftOrGrowLineBufferIfNecessary();

            int offset = _lineOffset + _lineLength;
            int bytesRead = _stream.Read(
#if NET
                _lineBuffer.AsSpan(offset));
#else
                _lineBuffer, offset, _lineBuffer.Length - offset);
#endif

            if (bytesRead > 0)
            {
                _lineLength += bytesRead;
            }
            else
            {
                // A non-positive read means end of stream.
                _eof = true;
                bytesRead = 0;
            }

            return bytesRead;
        }

        /// <summary>Reads data asynchronously from the stream into the line buffer.</summary>
        private async ValueTask<int> FillLineBufferAsync(CancellationToken cancellationToken)
        {
            ShiftOrGrowLineBufferIfNecessary();

            int offset = _lineOffset + _lineLength;
            int bytesRead = await
#if NET
                _stream.ReadAsync(_lineBuffer.AsMemory(offset), cancellationToken)
#else
                new ValueTask<int>(_stream.ReadAsync(_lineBuffer, offset, _lineBuffer.Length - offset, cancellationToken))
#endif
                .ConfigureAwait(false);

            if (bytesRead > 0)
            {
                _lineLength += bytesRead;
            }
            else
            {
                // A non-positive read means end of stream.
                _eof = true;
                bytesRead = 0;
            }

            return bytesRead;
        }

        /// <summary>Gets the UTF8 BOM.</summary>
        private static ReadOnlySpan<byte> Utf8Bom => [0xEF, 0xBB, 0xBF];

        /// <summary>Called at the beginning of processing to skip over an optional UTF8 byte order mark.</summary>
        private void SkipBomIfPresent()
        {
            Debug.Assert(_lineOffset == 0, $"Expected _lineOffset == 0, got {_lineOffset}");

            if (_lineBuffer.AsSpan(0, _lineLength).StartsWith(Utf8Bom))
            {
                // Advance past the 3-byte BOM without copying.
                _lineOffset += 3;
                _lineLength -= 3;
            }
        }

        /// <summary>Grows the buffer, returning the existing one to the ArrayPool and renting an ArrayPool replacement.</summary>
        private static void GrowBuffer([NotNull] ref byte[]? buffer, int minimumLength)
        {
            byte[]? toReturn = buffer;
            buffer = ArrayPool<byte>.Shared.Rent(Math.Max(minimumLength, DefaultArrayPoolRentSize));
            if (toReturn is not null)
            {
                Array.Copy(toReturn, buffer, toReturn.Length);
                ArrayPool<byte>.Shared.Return(toReturn);
            }
        }
    }
}
diff --git a/src/Shared/ServerSentEvents/ThrowHelper.cs b/src/Shared/ServerSentEvents/ThrowHelper.cs
new file mode 100644
index 00000000000..1ab8e8c9b21
--- /dev/null
+++ b/src/Shared/ServerSentEvents/ThrowHelper.cs
@@ -0,0 +1,34 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;

#pragma warning disable LA0001 // Use Microsoft.Shared.Diagnostics.Throws for improved performance

namespace System.Net.ServerSentEvents
{
    /// <summary>
    /// Centralized throw helpers so that throw sites stay out of callers, keeping those
    /// methods small and eligible for inlining.
    /// </summary>
    [ExcludeFromCodeCoverage]
    internal static class ThrowHelper
    {
        /// <summary>Throws an <see cref="ArgumentNullException"/> for the specified parameter.</summary>
        /// <param name="parameterName">The name of the argument that was null.</param>
        [DoesNotReturn]
        public static void ThrowArgumentNullException(string parameterName)
        {
            throw new ArgumentNullException(parameterName);
        }

        /// <summary>Throws an <see cref="InvalidOperationException"/> indicating that an SSE parser may be enumerated only once.</summary>
        [DoesNotReturn] // lets callers' flow analysis know control never returns
        public static void ThrowInvalidOperationException_EnumerateOnlyOnce()
        {
            throw new InvalidOperationException("The enumerable may be enumerated only once.");
        }

        /// <summary>Throws an <see cref="ArgumentException"/> indicating that the argument must not contain line breaks.</summary>
        /// <param name="parameterName">The name of the offending argument.</param>
        [DoesNotReturn]
        public static void ThrowArgumentException_CannotContainLineBreaks(string parameterName)
        {
            throw new ArgumentException("The argument cannot contain line breaks.", parameterName);
        }

        /// <summary>Throws an <see cref="ArgumentException"/> indicating that the argument must not be negative.</summary>
        /// <param name="parameterName">The name of the offending argument.</param>
        [DoesNotReturn]
        public static void ThrowArgumentException_CannotBeNegative(string parameterName)
        {
            throw new ArgumentException("The argument cannot be a negative value.", parameterName);
        }
    }
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs
new file mode 100644
index 00000000000..87df45b5bf3
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs
@@ -0,0 +1,740 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using OpenAI.Chat; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public static partial class OpenAISerializationTests +{ + [Fact] + public static async Task RequestDeserialization_SimpleMessage() + { + const string RequestJson = """ + {"messages":[{"role":"user","content":"hello"}],"model":"gpt-4o-mini","max_completion_tokens":10,"temperature":0.5} + """; + + using MemoryStream stream = new(Encoding.UTF8.GetBytes(RequestJson)); + OpenAIChatCompletionRequest request = await OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream); + + Assert.NotNull(request); + Assert.False(request.Stream); + Assert.Equal("gpt-4o-mini", request.ModelId); + + Assert.NotNull(request.Options); + Assert.Equal("gpt-4o-mini", request.Options.ModelId); + Assert.Equal(0.5f, request.Options.Temperature); + Assert.Equal(10, request.Options.MaxOutputTokens); + Assert.Null(request.Options.TopK); + Assert.Null(request.Options.TopP); + Assert.Null(request.Options.StopSequences); + Assert.Null(request.Options.AdditionalProperties); + Assert.Null(request.Options.Tools); + + ChatMessage message = Assert.Single(request.Messages); + Assert.Equal(ChatRole.User, message.Role); + AIContent content = Assert.Single(message.Contents); + TextContent textContent = Assert.IsType(content); + Assert.Equal("hello", textContent.Text); + Assert.Null(textContent.RawRepresentation); + Assert.Null(textContent.AdditionalProperties); + } + + [Fact] + public static async Task RequestDeserialization_SimpleMessage_Stream() + { + const string RequestJson = """ + 
{"messages":[{"role":"user","content":"hello"}],"model":"gpt-4o-mini","max_completion_tokens":20,"stream":true,"stream_options":{"include_usage":true},"temperature":0.5} + """; + + using MemoryStream stream = new(Encoding.UTF8.GetBytes(RequestJson)); + OpenAIChatCompletionRequest request = await OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream); + + Assert.NotNull(request); + Assert.True(request.Stream); + Assert.Equal("gpt-4o-mini", request.ModelId); + + Assert.NotNull(request.Options); + Assert.Equal("gpt-4o-mini", request.Options.ModelId); + Assert.Equal(0.5f, request.Options.Temperature); + Assert.Equal(20, request.Options.MaxOutputTokens); + Assert.Null(request.Options.TopK); + Assert.Null(request.Options.TopP); + Assert.Null(request.Options.StopSequences); + Assert.Null(request.Options.AdditionalProperties); + Assert.Null(request.Options.Tools); + + ChatMessage message = Assert.Single(request.Messages); + Assert.Equal(ChatRole.User, message.Role); + AIContent content = Assert.Single(message.Contents); + TextContent textContent = Assert.IsType(content); + Assert.Equal("hello", textContent.Text); + Assert.Null(textContent.RawRepresentation); + Assert.Null(textContent.AdditionalProperties); + } + + [Fact] + public static async Task RequestDeserialization_MultipleMessages() + { + const string RequestJson = """ + { + "messages": [ + { + "role": "system", + "content": "You are a really nice friend." + }, + { + "role": "user", + "content": "hello!" + }, + { + "role": "assistant", + "content": "hi, how are you?" + }, + { + "role": "user", + "content": "i\u0027m good. how are you?" 
+ } + ], + "model": "gpt-4o-mini", + "frequency_penalty": 0.75, + "presence_penalty": 0.5, + "seed":42, + "stop": [ "great" ], + "temperature": 0.25, + "user": "user", + "logprobs": true, + "logit_bias": { "42" : 0 }, + "parallel_tool_calls": true, + "top_logprobs": 42, + "metadata": { "key": "value" }, + "store": true + } + """; + + using MemoryStream stream = new(Encoding.UTF8.GetBytes(RequestJson)); + OpenAIChatCompletionRequest request = await OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream); + + Assert.NotNull(request); + Assert.False(request.Stream); + Assert.Equal("gpt-4o-mini", request.ModelId); + + Assert.NotNull(request.Options); + Assert.Equal("gpt-4o-mini", request.Options.ModelId); + Assert.Equal(0.25f, request.Options.Temperature); + Assert.Equal(0.75f, request.Options.FrequencyPenalty); + Assert.Equal(0.5f, request.Options.PresencePenalty); + Assert.Equal(42, request.Options.Seed); + Assert.Equal(["great"], request.Options.StopSequences); + Assert.NotNull(request.Options.AdditionalProperties); + Assert.Equal("user", request.Options.AdditionalProperties["EndUserId"]); + Assert.True((bool)request.Options.AdditionalProperties["IncludeLogProbabilities"]!); + Assert.Single((IDictionary)request.Options.AdditionalProperties["LogitBiases"]!); + Assert.True((bool)request.Options.AdditionalProperties["AllowParallelToolCalls"]!); + Assert.Equal(42, request.Options.AdditionalProperties["TopLogProbabilityCount"]!); + Assert.Single((IDictionary)request.Options.AdditionalProperties["Metadata"]!); + Assert.True((bool)request.Options.AdditionalProperties["StoredOutputEnabled"]!); + + Assert.Collection(request.Messages, + msg => + { + Assert.Equal(ChatRole.System, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + TextContent text = Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("You are a really nice friend.", text.Text); + Assert.Null(text.AdditionalProperties); + 
Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }, + msg => + { + Assert.Equal(ChatRole.User, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + TextContent text = Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("hello!", text.Text); + Assert.Null(text.AdditionalProperties); + Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }, + msg => + { + Assert.Equal(ChatRole.Assistant, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + TextContent text = Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("hi, how are you?", text.Text); + Assert.Null(text.AdditionalProperties); + Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }, + msg => + { + Assert.Equal(ChatRole.User, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + TextContent text = Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("i'm good. how are you?", text.Text); + Assert.Null(text.AdditionalProperties); + Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }); + } + + [Fact] + public static async Task RequestDeserialization_MultiPartSystemMessage() + { + const string RequestJson = """ + { + "messages": [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a really nice friend." + }, + { + "type": "text", + "text": "Really nice." + } + ] + }, + { + "role": "user", + "content": "hello!" 
+ } + ], + "model": "gpt-4o-mini" + } + """; + + using MemoryStream stream = new(Encoding.UTF8.GetBytes(RequestJson)); + OpenAIChatCompletionRequest request = await OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream); + + Assert.NotNull(request); + Assert.False(request.Stream); + Assert.Equal("gpt-4o-mini", request.ModelId); + + Assert.NotNull(request.Options); + Assert.Equal("gpt-4o-mini", request.Options.ModelId); + Assert.Null(request.Options.Temperature); + Assert.Null(request.Options.FrequencyPenalty); + Assert.Null(request.Options.PresencePenalty); + Assert.Null(request.Options.Seed); + Assert.Null(request.Options.StopSequences); + + Assert.Collection(request.Messages, + msg => + { + Assert.Equal(ChatRole.System, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + Assert.Collection(msg.Contents, + content => + { + TextContent text = Assert.IsType(content); + Assert.Equal("You are a really nice friend.", text.Text); + Assert.Null(text.AdditionalProperties); + Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }, + content => + { + TextContent text = Assert.IsType(content); + Assert.Equal("Really nice.", text.Text); + Assert.Null(text.AdditionalProperties); + Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }); + }, + msg => + { + Assert.Equal(ChatRole.User, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + TextContent text = Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("hello!", text.Text); + Assert.Null(text.AdditionalProperties); + Assert.Null(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }); + } + + [Fact] + public static async Task RequestDeserialization_ToolCall() + { + const string RequestJson = """ + { + "messages": [ + { + "role": "user", + "content": "How old is Alice?" 
+ } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "type": "function", + "function": { + "description": "Gets the age of the specified person.", + "name": "GetPersonAge", + "strict": true, + "parameters": { + "type": "object", + "required": [ + "personName" + ], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + } + } + } + ], + "tool_choice": "auto" + } + """; + + using MemoryStream stream = new(Encoding.UTF8.GetBytes(RequestJson)); + OpenAIChatCompletionRequest request = await OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream); + + Assert.NotNull(request); + Assert.False(request.Stream); + Assert.Equal("gpt-4o-mini", request.ModelId); + + Assert.NotNull(request.Options); + Assert.Equal("gpt-4o-mini", request.Options.ModelId); + Assert.Null(request.Options.Temperature); + Assert.Null(request.Options.FrequencyPenalty); + Assert.Null(request.Options.PresencePenalty); + Assert.Null(request.Options.Seed); + Assert.Null(request.Options.StopSequences); + + Assert.Equal(ChatToolMode.Auto, request.Options.ToolMode); + Assert.NotNull(request.Options.Tools); + + AIFunction function = Assert.IsAssignableFrom(Assert.Single(request.Options.Tools)); + Assert.Equal("Gets the age of the specified person.", function.Metadata.Description); + Assert.Equal("GetPersonAge", function.Metadata.Name); + Assert.Equal("Strict", Assert.Single(function.Metadata.AdditionalProperties).Key); + Assert.Equal("Return parameter", function.Metadata.ReturnParameter.Description); + Assert.Equal("{}", Assert.IsType(function.Metadata.ReturnParameter.Schema).GetRawText()); + + AIFunctionParameterMetadata parameter = Assert.Single(function.Metadata.Parameters); + Assert.Equal("personName", parameter.Name); + Assert.True(parameter.IsRequired); + + JsonObject parameterSchema = Assert.IsType(JsonNode.Parse(Assert.IsType(parameter.Schema).GetRawText())); + Assert.Equal(2, parameterSchema.Count); + 
Assert.Equal("The person whose age is being requested", (string)parameterSchema["description"]!); + Assert.Equal("string", (string)parameterSchema["type"]!); + + Dictionary functionArgs = new() { ["personName"] = "John" }; + var ex = await Assert.ThrowsAsync(() => function.InvokeAsync(functionArgs)); + Assert.Contains("does not support being invoked.", ex.Message); + } + + [Fact] + public static async Task RequestDeserialization_ToolChatMessage() + { + const string RequestJson = """ + { + "messages": [ + { + "role": "assistant", + "tool_calls": [ + { + "id": "12345", + "type": "function", + "function": { + "name": "SayHello", + "arguments": "null" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "12345", + "content": "42" + } + ], + "model": "gpt-4o-mini" + } + """; + + using MemoryStream stream = new(Encoding.UTF8.GetBytes(RequestJson)); + OpenAIChatCompletionRequest request = await OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream); + + Assert.NotNull(request); + Assert.False(request.Stream); + Assert.Equal("gpt-4o-mini", request.ModelId); + + Assert.NotNull(request.Options); + Assert.Equal("gpt-4o-mini", request.Options.ModelId); + Assert.Null(request.Options.Temperature); + Assert.Null(request.Options.FrequencyPenalty); + Assert.Null(request.Options.PresencePenalty); + Assert.Null(request.Options.Seed); + Assert.Null(request.Options.StopSequences); + + Assert.Collection(request.Messages, + msg => + { + Assert.Equal(ChatRole.Assistant, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + FunctionCallContent text = Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("12345", text.CallId); + Assert.Null(text.AdditionalProperties); + Assert.IsType(text.RawRepresentation); + Assert.Null(text.AdditionalProperties); + }, + msg => + { + Assert.Equal(ChatRole.Tool, msg.Role); + Assert.Null(msg.RawRepresentation); + Assert.Null(msg.AdditionalProperties); + + FunctionResultContent frc = 
Assert.IsType(Assert.Single(msg.Contents)); + Assert.Equal("SayHello", frc.Name); + Assert.Equal("12345", frc.CallId); + Assert.Equal(42, Assert.IsType(frc.Result).GetInt32()); + Assert.Null(frc.AdditionalProperties); + Assert.Null(frc.RawRepresentation); + Assert.Null(frc.AdditionalProperties); + }); + } + + [Fact] + public static async Task SerializeCompletion_SingleChoice() + { + ChatMessage message = new() + { + Role = ChatRole.Assistant, + Contents = [ + new TextContent("Hello! How can I assist you today?"), + new FunctionCallContent( + "callId", + "MyCoolFunc", + new Dictionary + { + ["arg1"] = 42, + ["arg2"] = "str", + }) + ] + }; + + ChatCompletion completion = new(message) + { + CompletionId = "chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", + ModelId = "gpt-4o-mini-2024-07-18", + CreatedAt = DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), + FinishReason = ChatFinishReason.Stop, + Usage = new() + { + InputTokenCount = 8, + OutputTokenCount = 9, + TotalTokenCount = 17, + AdditionalCounts = new() + { + { "InputTokenDetails.AudioTokenCount", 1 }, + { "InputTokenDetails.CachedTokenCount", 13 }, + { "OutputTokenDetails.AudioTokenCount", 2 }, + { "OutputTokenDetails.ReasoningTokenCount", 90 }, + } + }, + AdditionalProperties = new() + { + [nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = "fp_f85bea6784", + } + }; + + using MemoryStream stream = new(); + await OpenAISerializationHelpers.SerializeAsync(stream, completion); + string result = Encoding.UTF8.GetString(stream.ToArray()); + + AssertJsonEqual(""" + { + "id": "chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Hello! 
How can I assist you today?", + "refusal": null, + "tool_calls": [ + { + "id": "callId", + "type": "function", + "function": { + "name": "MyCoolFunc", + "arguments": "{\r\n \u0022arg1\u0022: 42,\r\n \u0022arg2\u0022: \u0022str\u0022\r\n}" + } + } + ], + "role": "assistant" + }, + "logprobs": { + "content": [], + "refusal": [] + } + } + ], + "created": 1727888631, + "model": "gpt-4o-mini-2024-07-18", + "system_fingerprint": "fp_f85bea6784", + "object": "chat.completion", + "usage": { + "completion_tokens": 9, + "prompt_tokens": 8, + "total_tokens": 17, + "completion_tokens_details": { + "audio_tokens": 2, + "reasoning_tokens": 90 + }, + "prompt_tokens_details": { + "audio_tokens": 1, + "cached_tokens": 13 + } + } + } + """, result); + } + + [Fact] + public static async Task SerializeCompletion_ManyChoices_ThrowsNotSupportedException() + { + ChatMessage message1 = new() + { + Role = ChatRole.Assistant, + Text = "Hello! How can I assist you today?", + }; + + ChatMessage message2 = new() + { + Role = ChatRole.Assistant, + Text = "Hey there! 
How can I help?", + }; + + ChatCompletion completion = new([message1, message2]); + + using MemoryStream stream = new(); + var ex = await Assert.ThrowsAsync(() => OpenAISerializationHelpers.SerializeAsync(stream, completion)); + Assert.Contains("multiple choices", ex.Message); + } + + [Fact] + public static async Task SerializeStreamingCompletion() + { + static async IAsyncEnumerable CreateStreamingCompletion() + { + for (int i = 0; i < 5; i++) + { + List contents = [new TextContent($"Streaming update {i}")]; + + if (i == 2) + { + FunctionCallContent fcc = new( + "callId", + "MyCoolFunc", + new Dictionary + { + ["arg1"] = 42, + ["arg2"] = "str", + }); + + contents.Add(fcc); + } + + if (i == 4) + { + UsageDetails usageDetails = new() + { + InputTokenCount = 8, + OutputTokenCount = 9, + TotalTokenCount = 17, + AdditionalCounts = new() + { + { "InputTokenDetails.AudioTokenCount", 1 }, + { "InputTokenDetails.CachedTokenCount", 13 }, + { "OutputTokenDetails.AudioTokenCount", 2 }, + { "OutputTokenDetails.ReasoningTokenCount", 90 }, + } + }; + + contents.Add(new UsageContent(usageDetails)); + } + + yield return new StreamingChatCompletionUpdate + { + CompletionId = "chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl", + ModelId = "gpt-4o-mini-2024-07-18", + CreatedAt = DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), + Role = ChatRole.Assistant, + Contents = contents, + FinishReason = i == 4 ? 
ChatFinishReason.Stop : null, + AdditionalProperties = new() + { + [nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = "fp_f85bea6784", + }, + }; + + await Task.Yield(); + } + } + + using MemoryStream stream = new(); + await OpenAISerializationHelpers.SerializeStreamingAsync(stream, CreateStreamingCompletion()); + string result = Encoding.UTF8.GetString(stream.ToArray()); + + AssertSseEqual(""" + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","choices":[{"delta":{"content":"Streaming update 0","tool_calls":[],"role":"assistant"},"logprobs":{"content":[],"refusal":[]},"finish_reason":"stop","index":0}],"created":1727888631,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","choices":[{"delta":{"content":"Streaming update 1","tool_calls":[],"role":"assistant"},"logprobs":{"content":[],"refusal":[]},"finish_reason":"stop","index":0}],"created":1727888631,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","choices":[{"delta":{"content":"Streaming update 2","tool_calls":[{"index":0,"id":"callId","type":"function","function":{"name":"MyCoolFunc","arguments":"{\r\n \u0022arg1\u0022: 42,\r\n \u0022arg2\u0022: \u0022str\u0022\r\n}"}}],"role":"assistant"},"logprobs":{"content":[],"refusal":[]},"finish_reason":"stop","index":0}],"created":1727888631,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","choices":[{"delta":{"content":"Streaming update 3","tool_calls":[],"role":"assistant"},"logprobs":{"content":[],"refusal":[]},"finish_reason":"stop","index":0}],"created":1727888631,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","choices":[{"delta":{"content":"Streaming update 4","tool_calls":[],"role":"assistant"},"logprobs":{"content":[],"refusal":[]},"finish_reason":"stop","index":0}],"created":1727888631,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","object":"chat.completion.chunk","usage":{"completion_tokens":9,"prompt_tokens":8,"total_tokens":17,"completion_tokens_details":{"audio_tokens":2,"reasoning_tokens":90},"prompt_tokens_details":{"audio_tokens":1,"cached_tokens":13}}} + + data: [DONE] + + + """, result); + } + + [Fact] + public static async Task SerializationHelpers_NullArguments_ThrowsArgumentNullException() + { + await Assert.ThrowsAsync<ArgumentNullException>(() => OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(null!)); + + await Assert.ThrowsAsync<ArgumentNullException>(() => OpenAISerializationHelpers.SerializeAsync(null!, new(new ChatMessage()))); + await Assert.ThrowsAsync<ArgumentNullException>(() => OpenAISerializationHelpers.SerializeAsync(new MemoryStream(), null!)); + + await Assert.ThrowsAsync<ArgumentNullException>(() => OpenAISerializationHelpers.SerializeStreamingAsync(null!, GetStreamingChatCompletion())); + await Assert.ThrowsAsync<ArgumentNullException>(() => OpenAISerializationHelpers.SerializeStreamingAsync(new MemoryStream(), null!)); + + static async IAsyncEnumerable<StreamingChatCompletionUpdate> GetStreamingChatCompletion() + { + yield return new StreamingChatCompletionUpdate(); + await Task.CompletedTask; + } + } + + [Fact] + public static async Task SerializationHelpers_HonorCancellationToken() + { + CancellationToken canceledToken = new(canceled: true); + MemoryStream stream = new("{}"u8.ToArray()); + + await Assert.ThrowsAsync<OperationCanceledException>(() => OpenAISerializationHelpers.DeserializeChatCompletionRequestAsync(stream, cancellationToken: canceledToken)); + await Assert.ThrowsAsync<OperationCanceledException>(() => OpenAISerializationHelpers.SerializeAsync(stream, new(new ChatMessage()), cancellationToken: canceledToken)); + await Assert.ThrowsAsync<OperationCanceledException>(() => OpenAISerializationHelpers.SerializeStreamingAsync(stream, GetStreamingChatCompletion(),
cancellationToken: canceledToken)); + + static async IAsyncEnumerable<StreamingChatCompletionUpdate> GetStreamingChatCompletion() + { + yield return new StreamingChatCompletionUpdate(); + await Task.CompletedTask; + } + } + + [Fact] + public static async Task SerializationHelpers_HonorJsonSerializerOptions() + { + FunctionCallContent fcc = new( + "callId", + "MyCoolFunc", + new Dictionary<string, object?> + { + ["arg1"] = new SomeFunctionArgument(), + }); + + ChatCompletion completion = new(new ChatMessage + { + Role = ChatRole.Assistant, + Contents = [fcc], + }); + + using MemoryStream stream = new(); + + // Passing a JSO that contains a contract for the function argument results in successful serialization. + await OpenAISerializationHelpers.SerializeAsync(stream, completion, options: JsonContextWithFunctionArgument.Default.Options); + stream.Position = 0; + + await OpenAISerializationHelpers.SerializeStreamingAsync(stream, GetStreamingCompletion(), options: JsonContextWithFunctionArgument.Default.Options); + stream.Position = 0; + + // Passing a JSO without a contract for the function argument results in failed serialization.
+ await Assert.ThrowsAsync<NotSupportedException>(() => OpenAISerializationHelpers.SerializeAsync(stream, completion, options: JsonContextWithoutFunctionArgument.Default.Options)); + await Assert.ThrowsAsync<NotSupportedException>(() => OpenAISerializationHelpers.SerializeStreamingAsync(stream, GetStreamingCompletion(), options: JsonContextWithoutFunctionArgument.Default.Options)); + + async IAsyncEnumerable<StreamingChatCompletionUpdate> GetStreamingCompletion() + { + yield return new StreamingChatCompletionUpdate + { + Contents = [fcc], + }; + await Task.CompletedTask; + } + } + + private class SomeFunctionArgument; + + [JsonSerializable(typeof(SomeFunctionArgument))] + [JsonSerializable(typeof(IDictionary<string, object?>))] + private partial class JsonContextWithFunctionArgument : JsonSerializerContext; + + [JsonSerializable(typeof(int))] + [JsonSerializable(typeof(IDictionary<string, object?>))] + private partial class JsonContextWithoutFunctionArgument : JsonSerializerContext; + + private static void AssertJsonEqual(string expected, string actual) + { + JsonNode? expectedNode = JsonNode.Parse(expected); + JsonNode? actualNode = JsonNode.Parse(actual); + + if (!JsonNode.DeepEquals(expectedNode, actualNode)) + { + // JSON documents are not equal, assert on + // normal form strings for better reporting. + expected = expectedNode?.ToJsonString() ?? "null"; + actual = actualNode?.ToJsonString() ?? "null"; + Assert.Equal(expected.NormalizeNewLines(), actual.NormalizeNewLines()); + } + } + + private static void AssertSseEqual(string expected, string actual) + { + Assert.Equal(expected.NormalizeNewLines(), actual.NormalizeNewLines()); + } + + private static string NormalizeNewLines(this string value) => + value.Replace("\r\n", "\n").Replace("\\r\\n", "\\n"); +}