diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index df417b1aaa291..dd5727d983b61 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -184,7 +184,7 @@ - + diff --git a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md index c8b78cea80db8..803541cd70e03 100644 --- a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md +++ b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md @@ -1,5 +1,35 @@ # Release History +## 2.1.0-beta.2 (2024-11-04) + +This update brings compatibility with the Azure OpenAI `2024-10-01-preview` service API version as well as the `2.1.0-beta.2` release of the `OpenAI` library. + +### Breaking Changes + +- `[Experimental]` `ChatCitation` and `ChatRetrievedDocument` have each replaced the `Uri` property of type `System.Uri` with a `string` property named `Url`. This aligns with the REST specification and accounts for the wire value of `url` not always providing a valid RFC 3986 identifier [[azure-sdk-for-net \#46793](https://github.com/Azure/azure-sdk-for-net/issues/46793)] + +### Features Added + +- The included update via `2024-09-01-preview` brings AOAI support for streaming token usage in chat completions; `Usage` is now automatically populated in `StreamingChatCompletionUpdate` instances. + - Note 1: this feature is not yet compatible when using On Your Data features (after invoking the `.AddDataSource()` extension method on `ChatCompletionOptions`) + - Note 2: this feature is not yet compatible when using image input (a `ChatMessageContentPart` of `Kind` `Image`) +- `2024-10-01-preview` further adds support for ungrounded content detection in chat completion content filter results via the `UngroundedMaterial` property on `ResponseContentFilterResult`, as retrieved from a chat completion via the `GetResponseContentFilterResult()` extension method. + +Via `OpenAI 2.0.0-beta.2`: + +- Made improvements to the experimental Realtime API. 
Please note this features area is currently under rapid development and not all changes may be reflected here. + - Several types have been renamed for consistency and clarity. + - ConversationRateLimitsUpdate (previously ConversationRateLimitsUpdatedUpdate) now includes named RequestDetails and TokenDetails properties, mapping to the corresponding named items in the underlying rate_limits command payload. + +### Bugs Fixed + +- Addressed an HTTP 401 issue that caused certain connection retry attempts, such as those triggered for HTTP 429 rate limiting errors, to sometimes generate a malformed request with multiple `Authorization` headers that would then be rejected. [#46401](https://github.com/Azure/azure-sdk-for-net/pull/46401) +- Addressed an issue that caused `ChatCitation` and `ChatRetrievedDocument` to sometimes throw on deserialization, specifically when a returned value in the `url` JSON field was not populated with an RFC 3986 compliant identifier for `System.Uri` [[azure-sdk-for-net \#46793](https://github.com/Azure/azure-sdk-for-net/issues/46793)] + +Via `OpenAI 2.0.0-beta.2`: + +- Fixed serialization and deserialization of ConversationToolChoice literal values (such as "required"). + ## 2.1.0-beta.1 (2024-10-01) Relative to the prior GA release, this update restores preview surfaces, retargeting to the latest `2024-08-01-preview` service `api-version` label. It also brings early support for the newly-announced `/realtime` capabilities with `gpt-4o-realtime-preview`. You can read more about Azure OpenAI support for `/realtime` in the annoucement post here: https://azure.microsoft.com/blog/announcing-new-products-and-features-for-azure-openai-service-including-gpt-4o-realtime-preview-with-audio-and-speech-capabilities/ @@ -10,7 +40,7 @@ Relative to the prior GA release, this update restores preview surfaces, retarge - This maps to the new `/realtime` beta endpoint and is thus marked with a new `[Experimental("OPENAI002")]` diagnostic tag. 
- This is a very early version of the convenience surface and thus subject to significant change - Documentation and samples will arrive soon; in the interim, see the scenario test files (in `/tests`) for basic usage - - You can also find an external sample employing this client, together with Azure OpenAI support, at https://github.com/Azure-Samples/aoai-realtime-audio-sdk/tree/main/dotnet/samples/console + - You can also find an external sample employing this client, together with Azure OpenAI support, at https://github.com/Azure-Samples/aoai-realtime-audio-sdk/tree/main/dotnet/samples ## 2.0.0 (2024-09-30) diff --git a/sdk/openai/Azure.AI.OpenAI/Directory.Build.props b/sdk/openai/Azure.AI.OpenAI/Directory.Build.props index 8f8f47e265927..51ec75d706d50 100644 --- a/sdk/openai/Azure.AI.OpenAI/Directory.Build.props +++ b/sdk/openai/Azure.AI.OpenAI/Directory.Build.props @@ -1,5 +1,6 @@  + 2.1.0-beta.2 true $(RequiredTargetFrameworks) true @@ -26,14 +27,8 @@ AZURE_OPENAI_GA - - - - beta.1 - - diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs index 0fff6772293d6..54838af00b039 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs @@ -56,6 +56,7 @@ public AzureOpenAIClientOptions(ServiceVersion version = LatestVersion) { #if !AZURE_OPENAI_GA ServiceVersion.V2024_08_01_Preview => "2024-08-01-preview", + ServiceVersion.V2024_09_01_Preview => "2024-09-01-preview", ServiceVersion.V2024_10_01_Preview => "2024-10-01-preview", #endif ServiceVersion.V2024_06_01 => "2024-06-01", @@ -70,6 +71,7 @@ public enum ServiceVersion V2024_06_01 = 0, #if !AZURE_OPENAI_GA V2024_08_01_Preview = 1, + V2024_09_01_Preview = 2, V2024_10_01_Preview = 3, #endif } @@ -103,7 +105,7 @@ protected override TimeSpan GetNextDelay(PipelineMessage message, int tryCount) } #if !AZURE_OPENAI_GA - 
private const ServiceVersion LatestVersion = ServiceVersion.V2024_08_01_Preview; + private const ServiceVersion LatestVersion = ServiceVersion.V2024_10_01_Preview; #else private const ServiceVersion LatestVersion = ServiceVersion.V2024_06_01; #endif diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs index 036b82e222794..7d5c692ab45d0 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using Azure.AI.OpenAI.Internal; using OpenAI.Chat; using System.ClientModel; using System.ClientModel.Primitives; +using System.Data.SqlTypes; using System.Diagnostics.CodeAnalysis; +#pragma warning disable AOAI001 #pragma warning disable AZC0112 namespace Azure.AI.OpenAI.Chat; @@ -63,7 +66,7 @@ public override CollectionResult CompleteChatStre /// public override AsyncCollectionResult CompleteChatStreamingAsync(IEnumerable messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default) { - PostfixClearStreamOptions(ref options); + PostfixClearStreamOptions(messages, ref options); PostfixSwapMaxTokens(ref options); return base.CompleteChatStreamingAsync(messages, options, cancellationToken); } @@ -71,25 +74,78 @@ public override AsyncCollectionResult CompleteCha /// public override CollectionResult CompleteChatStreaming(IEnumerable messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default) { - PostfixClearStreamOptions(ref options); + PostfixClearStreamOptions(messages, ref options); PostfixSwapMaxTokens(ref options); return base.CompleteChatStreaming(messages, options, cancellationToken); } - private static void PostfixClearStreamOptions(ref ChatCompletionOptions options) + /** + * As of 2024-09-01-preview, stream_options support 
for include_usage (which reports token usage while streaming) + * is conditionally supported: + * - When using On Your Data (non-null data_sources), stream_options is not considered valid + * - When using image input (any content part of "image" type), stream_options is not considered valid + * - Otherwise, stream_options can be defaulted to enabled per parity surface. + */ + private static void PostfixClearStreamOptions(IEnumerable messages, ref ChatCompletionOptions options) { - options ??= new(); - options.StreamOptions = null; + if (AdditionalPropertyHelpers + .GetAdditionalListProperty(options?.SerializedAdditionalRawData, "data_sources")?.Count > 0 + || messages?.Any( + message => message?.Content?.Any( + contentPart => contentPart?.Kind == ChatMessageContentPartKind.Image) == true) + == true) + { + options ??= new(); + options.StreamOptions = null; + } } + /** + * As of 2024-09-01-preview, Azure OpenAI conditionally supports the use of the new max_completion_tokens property: + * - The o1-mini and o1-preview models accept max_completion_tokens and reject max_tokens + * - All other models reject max_completion_tokens and accept max_tokens + * To handle this, each request will manipulate serialization overrides: + * - If max tokens aren't set, no action is taken + * - If serialization of max_tokens has already been blocked (e.g. 
via the public extension method), no + * additional logic is used and new serialization to max_completion_tokens will occur + * - Otherwise, serialization of max_completion_tokens is blocked and an override serialization of the + * corresponding max_tokens value is established + */ private static void PostfixSwapMaxTokens(ref ChatCompletionOptions options) { options ??= new(); - if (options.MaxOutputTokenCount is not null) + bool valueIsSet = options.MaxOutputTokenCount is not null; + bool oldPropertyBlocked = AdditionalPropertyHelpers.GetIsEmptySentinelValue(options.SerializedAdditionalRawData, "max_tokens"); + + if (valueIsSet) + { + if (!oldPropertyBlocked) + { + options.SerializedAdditionalRawData ??= new ChangeTrackingDictionary(); + AdditionalPropertyHelpers.SetEmptySentinelValue(options.SerializedAdditionalRawData, "max_completion_tokens"); + options.SerializedAdditionalRawData["max_tokens"] = BinaryData.FromObjectAsJson(options.MaxOutputTokenCount); + } + else + { + // Allow standard serialization to the new property to occur; remove overrides + if (options.SerializedAdditionalRawData.ContainsKey("max_completion_tokens")) + { + options.SerializedAdditionalRawData.Remove("max_completion_tokens"); + } + } + } + else { - options.SerializedAdditionalRawData ??= new Dictionary(); - options.SerializedAdditionalRawData["max_completion_tokens"] = BinaryData.FromObjectAsJson("__EMPTY__"); - options.SerializedAdditionalRawData["max_tokens"] = BinaryData.FromObjectAsJson(options.MaxOutputTokenCount); + if (!AdditionalPropertyHelpers.GetIsEmptySentinelValue(options.SerializedAdditionalRawData, "max_tokens") + && options.SerializedAdditionalRawData?.ContainsKey("max_tokens") == true) + { + options.SerializedAdditionalRawData.Remove("max_tokens"); + } + if (!AdditionalPropertyHelpers.GetIsEmptySentinelValue(options.SerializedAdditionalRawData, "max_completion_tokens") + && options.SerializedAdditionalRawData?.ContainsKey("max_completion_tokens") == true) + { + 
options.SerializedAdditionalRawData.Remove("max_completion_tokens"); + } } } } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs index dc88aed8299de..e53541fe1da26 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs @@ -36,6 +36,25 @@ public static IReadOnlyList GetDataSources(this ChatCompletionOp "data_sources") as IReadOnlyList; } + [Experimental("AOAI001")] + public static void SetNewMaxCompletionTokensPropertyEnabled(this ChatCompletionOptions options, bool newPropertyEnabled = true) + { + if (newPropertyEnabled) + { + // Blocking serialization of max_tokens via dictionary acts as a signal to skip pre-serialization fixup + AdditionalPropertyHelpers.SetEmptySentinelValue(options.SerializedAdditionalRawData, "max_tokens"); + } + else + { + // In the absence of a dictionary serialization block to max_tokens, the newer property name will + // automatically be blocked and the older property name will be used via dictionary override + if (options?.SerializedAdditionalRawData?.ContainsKey("max_tokens") == true) + { + options?.SerializedAdditionalRawData?.Remove("max_tokens"); + } + } + } + [Experimental("AOAI001")] public static RequestContentFilterResult GetRequestContentFilterResult(this ChatCompletion chatCompletion) { diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs index e0dc3895e0718..2dca2d7d694c8 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs @@ -12,9 +12,6 @@ namespace Azure.AI.OpenAI.Chat; public partial class ChatCitation { - /// The location of the citation. - [CodeGenMember("Url")] - public Uri Uri { get; } /// The file path for the citation. 
[CodeGenMember("Filepath")] public string FilePath { get; } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs index b0703091af032..74d841304ba79 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs @@ -14,8 +14,4 @@ public partial class ChatRetrievedDocument /// The file path for the citation. [CodeGenMember("Filepath")] public string FilePath { get; } - - /// The location of the citation. - [CodeGenMember("Url")] - public Uri Uri { get; } } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs index bc0c0a10b3c7a..6086bdfdacf68 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs @@ -8,6 +8,8 @@ namespace Azure.AI.OpenAI.Internal; internal static class AdditionalPropertyHelpers { + private static string SARD_EMPTY_SENTINEL = "__EMPTY__"; + internal static T GetAdditionalProperty(IDictionary additionalProperties, string key) where T : class, IJsonModel { @@ -45,4 +47,17 @@ internal static void SetAdditionalProperty(IDictionary ad BinaryData binaryValue = BinaryData.FromStream(stream); additionalProperties[key] = binaryValue; } + + internal static void SetEmptySentinelValue(IDictionary additionalProperties, string key) + { + Argument.AssertNotNull(additionalProperties, nameof(additionalProperties)); + additionalProperties[key] = BinaryData.FromObjectAsJson(SARD_EMPTY_SENTINEL); + } + + internal static bool GetIsEmptySentinelValue(IDictionary additionalProperties, string key) + { + return additionalProperties is not null + && additionalProperties.TryGetValue(key, out BinaryData existingValue) + && 
StringComparer.OrdinalIgnoreCase.Equals(existingValue.ToString(), $@"""{SARD_EMPTY_SENTINEL}"""); + } } \ No newline at end of file diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterTextSpanResult.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterTextSpanResult.cs new file mode 100644 index 0000000000000..d1edc581f2768 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterTextSpanResult.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Diagnostics.CodeAnalysis; + +namespace Azure.AI.OpenAI; + +[Experimental("AOAI001")] +[CodeGenModel("AzureContentFilterCompletionTextSpanDetectionResult")] +public partial class ContentFilterTextSpanResult +{ } \ No newline at end of file diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs index a5504cd496951..ff66fd09eb877 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs @@ -8,3 +8,4 @@ namespace Azure.AI.OpenAI; [Experimental("AOAI001")][CodeGenModel("AzureContentFilterDetectionResult")] public partial class ContentFilterDetectionResult { } [Experimental("AOAI001")][CodeGenModel("AzureContentFilterResultForChoiceProtectedMaterialCode")] public partial class ContentFilterProtectedMaterialResult { } [Experimental("AOAI001")][CodeGenModel("AzureContentFilterSeverityResultSeverity")] public readonly partial struct ContentFilterSeverity { } +[Experimental("AOAI001")][CodeGenModel("AzureContentFilterCompletionTextSpan")] public partial class ContentFilterTextSpan { } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs index 8354486cbaed9..029a7f9ac8d96 100644 --- 
a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs @@ -52,8 +52,6 @@ private static Uri GetEndpoint(Uri endpoint, string deploymentName, string apiVe _ => uriBuilder.Scheme, }; - apiVersion = "2024-10-01-preview"; - bool isLegacyNoDeployment = string.IsNullOrEmpty(deploymentName); string requiredPathSuffix = isLegacyNoDeployment ? "realtime" : "openai/realtime"; diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs index aa3d67dbb9369..38953624b0ba5 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using Azure.Core; -using OpenAI.RealtimeConversation; using System.ClientModel.Primitives; using System.ComponentModel; using System.Net.WebSockets; @@ -14,31 +13,74 @@ internal partial class AzureRealtimeConversationSession : RealtimeConversationSe [EditorBrowsable(EditorBrowsableState.Never)] protected internal override async Task ConnectAsync(RequestOptions options) { - ClientUriBuilder uriBuilder = new(); - uriBuilder.Reset(_endpoint); + WebSocket?.Dispose(); - if (_tokenCredential is not null) + // Temporary mitigation of static RPM rate limiting behavior: pending protocol-based rate limit integration with + // rate_limits.updated, use a static retry strategy upon receiving a 429 error. 
+ + bool shouldRetryOnFailure = false; + int rateLimitRetriesUsed = 0; + const int maximumRateLimitRetries = 3; + TimeSpan timeBetweenRateLimitRetries = TimeSpan.FromSeconds(5); + + do + { + try + { + ClientWebSocket webSocket = await CreateAzureWebSocketAsync(options).ConfigureAwait(false); + await webSocket.ConnectAsync(_endpoint, options?.CancellationToken ?? default) + .ConfigureAwait(false); + WebSocket = webSocket; + } + catch (WebSocketException webSocketException) + { + shouldRetryOnFailure + = webSocketException.Message?.Contains("429") == true && rateLimitRetriesUsed++ < maximumRateLimitRetries; + if (shouldRetryOnFailure) + { + await Task.Delay(timeBetweenRateLimitRetries).ConfigureAwait(false); + } + else + { + throw; + } + } + } while (shouldRetryOnFailure); + } + + private async Task CreateAzureWebSocketAsync(RequestOptions options) + { + string clientRequestId = Guid.NewGuid().ToString(); + + ClientWebSocket clientWebSocket = new(); + clientWebSocket.Options.AddSubProtocol("realtime"); + clientWebSocket.Options.SetRequestHeader("openai-beta", $"realtime=v1"); + clientWebSocket.Options.SetRequestHeader("x-ms-client-request-id", clientRequestId); + + try { - AccessToken token = await _tokenCredential.GetTokenAsync(_tokenRequestContext, options?.CancellationToken ?? 
default).ConfigureAwait(false); - _clientWebSocket.Options.SetRequestHeader("Authorization", $"Bearer {token.Token}"); + clientWebSocket.Options.SetRequestHeader("User-Agent", _userAgent); } - else + catch (ArgumentException argumentException) { - _keyCredential.Deconstruct(out string dangerousCredential); - _clientWebSocket.Options.SetRequestHeader("api-key", dangerousCredential); - // uriBuilder.AppendQuery("api-key", dangerousCredential, escape: false); + throw new PlatformNotSupportedException( + $"{nameof(RealtimeConversationClient)} is not yet supported on older .NET framework targets.", + argumentException); } - Uri endpoint = uriBuilder.ToUri(); - - try + if (_tokenCredential is not null) { - await _clientWebSocket.ConnectAsync(endpoint, options?.CancellationToken ?? default) + TokenRequestContext tokenRequestContext = new(_tokenAuthorizationScopes.ToArray(), clientRequestId); + AccessToken token = await _tokenCredential.GetTokenAsync(tokenRequestContext, options?.CancellationToken ?? 
default) .ConfigureAwait(false); + clientWebSocket.Options.SetRequestHeader("Authorization", $"Bearer {token.Token}"); } - catch (WebSocketException) + else { - throw; + _keyCredential.Deconstruct(out string dangerousCredential); + clientWebSocket.Options.SetRequestHeader("api-key", dangerousCredential); } + + return clientWebSocket; } } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs index 1ee9413dfe952..0fe51b40d754a 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs @@ -3,11 +3,10 @@ using System.ClientModel; using System.ClientModel.Primitives; -using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Net; +using System.Net.WebSockets; using Azure.Core; -using OpenAI.RealtimeConversation; namespace Azure.AI.OpenAI.RealtimeConversation; @@ -18,8 +17,7 @@ internal partial class AzureRealtimeConversationSession : RealtimeConversationSe private readonly ApiKeyCredential _keyCredential; private readonly TokenCredential _tokenCredential; private readonly IEnumerable _tokenAuthorizationScopes; - private readonly TokenRequestContext _tokenRequestContext; - private readonly string _clientRequestId; + private readonly string _userAgent; protected internal AzureRealtimeConversationSession( AzureRealtimeConversationClient parentClient, @@ -41,33 +39,12 @@ protected internal AzureRealtimeConversationSession( { _tokenCredential = credential; _tokenAuthorizationScopes = tokenAuthorizationScopes; - _tokenRequestContext = new(_tokenAuthorizationScopes.ToArray(), parentRequestId: _clientRequestId); } private AzureRealtimeConversationSession(AzureRealtimeConversationClient parentClient, Uri endpoint, string userAgent) : base(parentClient, 
endpoint, credential: new("placeholder")) { - _clientRequestId = Guid.NewGuid().ToString(); - _endpoint = endpoint; - _clientWebSocket.Options.AddSubProtocol("realtime"); - _clientWebSocket.Options.SetRequestHeader("User-Agent", userAgent); - _clientWebSocket.Options.SetRequestHeader("x-ms-client-request-id", _clientRequestId); - } - - internal override async Task SendCommandAsync(InternalRealtimeRequestCommand command, CancellationToken cancellationToken = default) - { - BinaryData requestData = ModelReaderWriter.Write(command); - - // Temporary backcompat quirk - if (command is InternalRealtimeRequestSessionUpdateCommand sessionUpdateCommand - && sessionUpdateCommand.Session?.TurnDetectionOptions is InternalRealtimeNoTurnDetection) - { - requestData = BinaryData.FromString(requestData.ToString() - .Replace(@"""turn_detection"":null", @"""turn_detection"":{""type"":""none""}")); - } - - RequestOptions cancellationOptions = cancellationToken.ToRequestOptions(); - await SendCommandAsync(requestData, cancellationOptions).ConfigureAwait(false); + _userAgent = userAgent; } } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs index d361f85eecfde..ec636d3186115 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs @@ -31,10 +31,10 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOpti writer.WritePropertyName("title"u8); writer.WriteStringValue(Title); } - if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Uri)) + if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Url)) { writer.WritePropertyName("url"u8); - writer.WriteStringValue(Uri.AbsoluteUri); + writer.WriteStringValue(Url); } if (SerializedAdditionalRawData?.ContainsKey("filepath") != true && Optional.IsDefined(FilePath)) { @@ 
-95,7 +95,7 @@ internal static ChatCitation DeserializeChatCitation(JsonElement element, ModelR } string content = default; string title = default; - Uri url = default; + string url = default; string filepath = default; string chunkId = default; double? rerankScore = default; @@ -115,11 +115,7 @@ internal static ChatCitation DeserializeChatCitation(JsonElement element, ModelR } if (property.NameEquals("url"u8)) { - if (property.Value.ValueKind == JsonValueKind.Null) - { - continue; - } - url = new Uri(property.Value.GetString()); + url = property.Value.GetString(); continue; } if (property.NameEquals("filepath"u8)) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs index 6d49dbb6950e9..74cd970915eab 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs @@ -54,16 +54,16 @@ internal ChatCitation(string content) /// Initializes a new instance of . /// The content of the citation. /// The title for the citation. - /// The URL of the citation. + /// The URL of the citation. /// The file path for the citation. /// The chunk ID for the citation. /// The rerank score for the retrieval. /// Keeps track of any properties unknown to the library. - internal ChatCitation(string content, string title, Uri uri, string filePath, string chunkId, double? rerankScore, IDictionary serializedAdditionalRawData) + internal ChatCitation(string content, string title, string url, string filePath, string chunkId, double? rerankScore, IDictionary serializedAdditionalRawData) { Content = content; Title = title; - Uri = uri; + Url = url; FilePath = filePath; ChunkId = chunkId; RerankScore = rerankScore; @@ -79,6 +79,8 @@ internal ChatCitation() public string Content { get; } /// The title for the citation. public string Title { get; } + /// The URL of the citation. + public string Url { get; } /// The chunk ID for the citation. 
public string ChunkId { get; } /// The rerank score for the retrieval. diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs index 2ab576c95b612..5f62fdce58d67 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs @@ -31,10 +31,10 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderW writer.WritePropertyName("title"u8); writer.WriteStringValue(Title); } - if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Uri)) + if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Url)) { writer.WritePropertyName("url"u8); - writer.WriteStringValue(Uri.AbsoluteUri); + writer.WriteStringValue(Url); } if (SerializedAdditionalRawData?.ContainsKey("filepath") != true && Optional.IsDefined(FilePath)) { @@ -120,7 +120,7 @@ internal static ChatRetrievedDocument DeserializeChatRetrievedDocument(JsonEleme } string content = default; string title = default; - Uri url = default; + string url = default; string filepath = default; string chunkId = default; double? 
rerankScore = default; @@ -144,11 +144,7 @@ internal static ChatRetrievedDocument DeserializeChatRetrievedDocument(JsonEleme } if (property.NameEquals("url"u8)) { - if (property.Value.ValueKind == JsonValueKind.Null) - { - continue; - } - url = new Uri(property.Value.GetString()); + url = property.Value.GetString(); continue; } if (property.NameEquals("filepath"u8)) diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs index 96e966ba79e03..145d0be9adf6c 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs @@ -60,7 +60,7 @@ internal ChatRetrievedDocument(string content, IEnumerable searchQueries /// Initializes a new instance of . /// The content of the citation. /// The title for the citation. - /// The URL of the citation. + /// The URL of the citation. /// The file path for the citation. /// The chunk ID for the citation. /// The rerank score for the retrieval. @@ -69,11 +69,11 @@ internal ChatRetrievedDocument(string content, IEnumerable searchQueries /// The original search score for the retrieval. /// If applicable, an indication of why the document was filtered. /// Keeps track of any properties unknown to the library. - internal ChatRetrievedDocument(string content, string title, Uri uri, string filePath, string chunkId, double? rerankScore, IReadOnlyList searchQueries, int dataSourceIndex, double? originalSearchScore, ChatDocumentFilterReason? filterReason, IDictionary serializedAdditionalRawData) + internal ChatRetrievedDocument(string content, string title, string url, string filePath, string chunkId, double? rerankScore, IReadOnlyList searchQueries, int dataSourceIndex, double? originalSearchScore, ChatDocumentFilterReason? 
filterReason, IDictionary serializedAdditionalRawData) { Content = content; Title = title; - Uri = uri; + Url = url; FilePath = filePath; ChunkId = chunkId; RerankScore = rerankScore; @@ -93,6 +93,8 @@ internal ChatRetrievedDocument() public string Content { get; } /// The title for the citation. public string Title { get; } + /// The URL of the citation. + public string Url { get; } /// The chunk ID for the citation. public string ChunkId { get; } /// The rerank score for the retrieval. diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.Serialization.cs new file mode 100644 index 0000000000000..98a048d355390 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.Serialization.cs @@ -0,0 +1,147 @@ +// + +#nullable disable + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace Azure.AI.OpenAI +{ + public partial class ContentFilterTextSpan : IJsonModel + { + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? 
((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support writing '{format}' format."); + } + + writer.WriteStartObject(); + if (SerializedAdditionalRawData?.ContainsKey("completion_start_offset") != true) + { + writer.WritePropertyName("completion_start_offset"u8); + writer.WriteNumberValue(CompletionStartOffset); + } + if (SerializedAdditionalRawData?.ContainsKey("completion_end_offset") != true) + { + writer.WritePropertyName("completion_end_offset"u8); + writer.WriteNumberValue(CompletionEndOffset); + } + if (SerializedAdditionalRawData != null) + { + foreach (var item in SerializedAdditionalRawData) + { + if (ModelSerializationExtensions.IsSentinelValue(item.Value)) + { + continue; + } + writer.WritePropertyName(item.Key); +#if NET6_0_OR_GREATER + writer.WriteRawValue(item.Value); +#else + using (JsonDocument document = JsonDocument.Parse(item.Value)) + { + JsonSerializer.Serialize(writer, document.RootElement); + } +#endif + } + } + writer.WriteEndObject(); + } + + ContentFilterTextSpan IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? 
((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support reading '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return DeserializeContentFilterTextSpan(document.RootElement, options); + } + + internal static ContentFilterTextSpan DeserializeContentFilterTextSpan(JsonElement element, ModelReaderWriterOptions options = null) + { + options ??= ModelSerializationExtensions.WireOptions; + + if (element.ValueKind == JsonValueKind.Null) + { + return null; + } + int completionStartOffset = default; + int completionEndOffset = default; + IDictionary serializedAdditionalRawData = default; + Dictionary rawDataDictionary = new Dictionary(); + foreach (var property in element.EnumerateObject()) + { + if (property.NameEquals("completion_start_offset"u8)) + { + completionStartOffset = property.Value.GetInt32(); + continue; + } + if (property.NameEquals("completion_end_offset"u8)) + { + completionEndOffset = property.Value.GetInt32(); + continue; + } + if (options.Format != "W") + { + rawDataDictionary ??= new Dictionary(); + rawDataDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); + } + } + serializedAdditionalRawData = rawDataDictionary; + return new ContentFilterTextSpan(completionStartOffset, completionEndOffset, serializedAdditionalRawData); + } + + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? 
((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + return ModelReaderWriter.Write(this, options); + default: + throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support writing '{options.Format}' format."); + } + } + + ContentFilterTextSpan IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return DeserializeContentFilterTextSpan(document.RootElement, options); + } + default: + throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support reading '{options.Format}' format."); + } + } + + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + /// Deserializes the model from a raw response. + /// The result to deserialize the model from. + internal static ContentFilterTextSpan FromResponse(PipelineResponse response) + { + using var document = JsonDocument.Parse(response.Content); + return DeserializeContentFilterTextSpan(document.RootElement); + } + + /// Convert into a . + internal virtual BinaryContent ToBinaryContent() + { + return BinaryContent.Create(this, ModelSerializationExtensions.WireOptions); + } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.cs new file mode 100644 index 0000000000000..2b04de8d01214 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.cs @@ -0,0 +1,74 @@ +// + +#nullable disable + +using System; +using System.Collections.Generic; + +namespace Azure.AI.OpenAI +{ + /// A representation of a span of completion text as used by Azure OpenAI content filter results. 
+ public partial class ContentFilterTextSpan + { + /// + /// Keeps track of any properties unknown to the library. + /// + /// To assign an object to the value of this property use . + /// + /// + /// To assign an already formatted json string to this property use . + /// + /// + /// Examples: + /// + /// + /// BinaryData.FromObjectAsJson("foo") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromString("\"foo\"") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromObjectAsJson(new { key = "value" }) + /// Creates a payload of { "key": "value" }. + /// + /// + /// BinaryData.FromString("{\"key\": \"value\"}") + /// Creates a payload of { "key": "value" }. + /// + /// + /// + /// + internal IDictionary SerializedAdditionalRawData { get; set; } + /// Initializes a new instance of . + /// Offset of the UTF32 code point which begins the span. + /// Offset of the first UTF32 code point which is excluded from the span. This field is always equal to completion_start_offset for empty spans. This field is always larger than completion_start_offset for non-empty spans. + internal ContentFilterTextSpan(int completionStartOffset, int completionEndOffset) + { + CompletionStartOffset = completionStartOffset; + CompletionEndOffset = completionEndOffset; + } + + /// Initializes a new instance of . + /// Offset of the UTF32 code point which begins the span. + /// Offset of the first UTF32 code point which is excluded from the span. This field is always equal to completion_start_offset for empty spans. This field is always larger than completion_start_offset for non-empty spans. + /// Keeps track of any properties unknown to the library. 
+ internal ContentFilterTextSpan(int completionStartOffset, int completionEndOffset, IDictionary serializedAdditionalRawData) + { + CompletionStartOffset = completionStartOffset; + CompletionEndOffset = completionEndOffset; + SerializedAdditionalRawData = serializedAdditionalRawData; + } + + /// Initializes a new instance of for deserialization. + internal ContentFilterTextSpan() + { + } + + /// Offset of the UTF32 code point which begins the span. + public int CompletionStartOffset { get; } + /// Offset of the first UTF32 code point which is excluded from the span. This field is always equal to completion_start_offset for empty spans. This field is always larger than completion_start_offset for non-empty spans. + public int CompletionEndOffset { get; } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.Serialization.cs new file mode 100644 index 0000000000000..4e7e99ab7caf4 --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.Serialization.cs @@ -0,0 +1,168 @@ +// + +#nullable disable + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace Azure.AI.OpenAI +{ + public partial class ContentFilterTextSpanResult : IJsonModel + { + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? 
((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support writing '{format}' format."); + } + + writer.WriteStartObject(); + if (SerializedAdditionalRawData?.ContainsKey("filtered") != true) + { + writer.WritePropertyName("filtered"u8); + writer.WriteBooleanValue(Filtered); + } + if (SerializedAdditionalRawData?.ContainsKey("detected") != true) + { + writer.WritePropertyName("detected"u8); + writer.WriteBooleanValue(Detected); + } + if (SerializedAdditionalRawData?.ContainsKey("details") != true) + { + writer.WritePropertyName("details"u8); + writer.WriteStartArray(); + foreach (var item in Details) + { + writer.WriteObjectValue(item, options); + } + writer.WriteEndArray(); + } + if (SerializedAdditionalRawData != null) + { + foreach (var item in SerializedAdditionalRawData) + { + if (ModelSerializationExtensions.IsSentinelValue(item.Value)) + { + continue; + } + writer.WritePropertyName(item.Key); +#if NET6_0_OR_GREATER + writer.WriteRawValue(item.Value); +#else + using (JsonDocument document = JsonDocument.Parse(item.Value)) + { + JsonSerializer.Serialize(writer, document.RootElement); + } +#endif + } + } + writer.WriteEndObject(); + } + + ContentFilterTextSpanResult IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? 
((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support reading '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return DeserializeContentFilterTextSpanResult(document.RootElement, options); + } + + internal static ContentFilterTextSpanResult DeserializeContentFilterTextSpanResult(JsonElement element, ModelReaderWriterOptions options = null) + { + options ??= ModelSerializationExtensions.WireOptions; + + if (element.ValueKind == JsonValueKind.Null) + { + return null; + } + bool filtered = default; + bool detected = default; + IReadOnlyList details = default; + IDictionary serializedAdditionalRawData = default; + Dictionary rawDataDictionary = new Dictionary(); + foreach (var property in element.EnumerateObject()) + { + if (property.NameEquals("filtered"u8)) + { + filtered = property.Value.GetBoolean(); + continue; + } + if (property.NameEquals("detected"u8)) + { + detected = property.Value.GetBoolean(); + continue; + } + if (property.NameEquals("details"u8)) + { + List array = new List(); + foreach (var item in property.Value.EnumerateArray()) + { + array.Add(ContentFilterTextSpan.DeserializeContentFilterTextSpan(item, options)); + } + details = array; + continue; + } + if (options.Format != "W") + { + rawDataDictionary ??= new Dictionary(); + rawDataDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); + } + } + serializedAdditionalRawData = rawDataDictionary; + return new ContentFilterTextSpanResult(filtered, detected, details, serializedAdditionalRawData); + } + + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? 
((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + return ModelReaderWriter.Write(this, options); + default: + throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support writing '{options.Format}' format."); + } + } + + ContentFilterTextSpanResult IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return DeserializeContentFilterTextSpanResult(document.RootElement, options); + } + default: + throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support reading '{options.Format}' format."); + } + } + + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + /// Deserializes the model from a raw response. + /// The result to deserialize the model from. + internal static ContentFilterTextSpanResult FromResponse(PipelineResponse response) + { + using var document = JsonDocument.Parse(response.Content); + return DeserializeContentFilterTextSpanResult(document.RootElement); + } + + /// Convert into a . + internal virtual BinaryContent ToBinaryContent() + { + return BinaryContent.Create(this, ModelSerializationExtensions.WireOptions); + } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.cs new file mode 100644 index 0000000000000..9f7a3dba1e0fa --- /dev/null +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.cs @@ -0,0 +1,84 @@ +// + +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Azure.AI.OpenAI +{ + /// The AzureContentFilterCompletionTextSpanDetectionResult. 
+ public partial class ContentFilterTextSpanResult + { + /// + /// Keeps track of any properties unknown to the library. + /// + /// To assign an object to the value of this property use . + /// + /// + /// To assign an already formatted json string to this property use . + /// + /// + /// Examples: + /// + /// + /// BinaryData.FromObjectAsJson("foo") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromString("\"foo\"") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromObjectAsJson(new { key = "value" }) + /// Creates a payload of { "key": "value" }. + /// + /// + /// BinaryData.FromString("{\"key\": \"value\"}") + /// Creates a payload of { "key": "value" }. + /// + /// + /// + /// + internal IDictionary SerializedAdditionalRawData { get; set; } + /// Initializes a new instance of . + /// Whether the content detection resulted in a content filtering action. + /// Whether the labeled content category was detected in the content. + /// Detailed information about the detected completion text spans. + /// is null. + internal ContentFilterTextSpanResult(bool filtered, bool detected, IEnumerable details) + { + Argument.AssertNotNull(details, nameof(details)); + + Filtered = filtered; + Detected = detected; + Details = details.ToList(); + } + + /// Initializes a new instance of . + /// Whether the content detection resulted in a content filtering action. + /// Whether the labeled content category was detected in the content. + /// Detailed information about the detected completion text spans. + /// Keeps track of any properties unknown to the library. + internal ContentFilterTextSpanResult(bool filtered, bool detected, IReadOnlyList details, IDictionary serializedAdditionalRawData) + { + Filtered = filtered; + Detected = detected; + Details = details; + SerializedAdditionalRawData = serializedAdditionalRawData; + } + + /// Initializes a new instance of for deserialization. 
+ internal ContentFilterTextSpanResult() + { + } + + /// Whether the content detection resulted in a content filtering action. + public bool Filtered { get; } + /// Whether the labeled content category was detected in the content. + public bool Detected { get; } + /// Detailed information about the detected completion text spans. + public IReadOnlyList Details { get; } + } +} diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs index 5e9da43d62ebc..7d03d6540b2d1 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs @@ -66,6 +66,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelR writer.WritePropertyName("protected_material_code"u8); writer.WriteObjectValue(ProtectedMaterialCode, options); } + if (SerializedAdditionalRawData?.ContainsKey("ungrounded_material") != true && Optional.IsDefined(UngroundedMaterial)) + { + writer.WritePropertyName("ungrounded_material"u8); + writer.WriteObjectValue(UngroundedMaterial, options); + } if (SerializedAdditionalRawData != null) { foreach (var item in SerializedAdditionalRawData) @@ -117,6 +122,7 @@ internal static ResponseContentFilterResult DeserializeResponseContentFilterResu InternalAzureContentFilterResultForPromptContentFilterResultsError error = default; ContentFilterDetectionResult protectedMaterialText = default; ContentFilterProtectedMaterialResult protectedMaterialCode = default; + ContentFilterTextSpanResult ungroundedMaterial = default; IDictionary serializedAdditionalRawData = default; Dictionary rawDataDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -202,6 +208,15 @@ internal static ResponseContentFilterResult DeserializeResponseContentFilterResu protectedMaterialCode = 
ContentFilterProtectedMaterialResult.DeserializeContentFilterProtectedMaterialResult(property.Value, options); continue; } + if (property.NameEquals("ungrounded_material"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + ungroundedMaterial = ContentFilterTextSpanResult.DeserializeContentFilterTextSpanResult(property.Value, options); + continue; + } if (options.Format != "W") { rawDataDictionary ??= new Dictionary(); @@ -219,6 +234,7 @@ internal static ResponseContentFilterResult DeserializeResponseContentFilterResu error, protectedMaterialText, protectedMaterialCode, + ungroundedMaterial, serializedAdditionalRawData); } diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs index be77b4567b3d0..87fda89bae54b 100644 --- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs +++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs @@ -75,8 +75,9 @@ internal ResponseContentFilterResult() /// If present, details about an error that prevented content filtering from completing its evaluation. /// A detection result that describes a match against text protected under copyright or other status. /// A detection result that describes a match against licensed code or other protected source material. + /// /// Keeps track of any properties unknown to the library. 
- internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, ContentFilterSeverityResult hate, ContentFilterSeverityResult violence, ContentFilterSeverityResult selfHarm, ContentFilterDetectionResult profanity, ContentFilterBlocklistResult customBlocklists, InternalAzureContentFilterResultForPromptContentFilterResultsError error, ContentFilterDetectionResult protectedMaterialText, ContentFilterProtectedMaterialResult protectedMaterialCode, IDictionary serializedAdditionalRawData) + internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, ContentFilterSeverityResult hate, ContentFilterSeverityResult violence, ContentFilterSeverityResult selfHarm, ContentFilterDetectionResult profanity, ContentFilterBlocklistResult customBlocklists, InternalAzureContentFilterResultForPromptContentFilterResultsError error, ContentFilterDetectionResult protectedMaterialText, ContentFilterProtectedMaterialResult protectedMaterialCode, ContentFilterTextSpanResult ungroundedMaterial, IDictionary serializedAdditionalRawData) { Sexual = sexual; Hate = hate; @@ -87,6 +88,7 @@ internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, Content Error = error; ProtectedMaterialText = protectedMaterialText; ProtectedMaterialCode = protectedMaterialCode; + UngroundedMaterial = ungroundedMaterial; SerializedAdditionalRawData = serializedAdditionalRawData; } @@ -125,5 +127,7 @@ internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, Content public ContentFilterDetectionResult ProtectedMaterialText { get; } /// A detection result that describes a match against licensed code or other protected source material. public ContentFilterProtectedMaterialResult ProtectedMaterialCode { get; } + /// Gets the ungrounded material. 
+ public ContentFilterTextSpanResult UngroundedMaterial { get; } } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs index 86bde5605ab71..d5957f918e3ba 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs @@ -87,6 +87,7 @@ public async Task ChatWithImagesStreaming(bool useUri) { bool foundPromptFilter = false; bool foundResponseFilter = false; + ChatTokenUsage? usage = null; StringBuilder content = new(); ChatClient client = GetTestClient("vision"); @@ -123,9 +124,11 @@ public async Task ChatWithImagesStreaming(bool useUri) await foreach (StreamingChatCompletionUpdate update in response) { - ValidateUpdate(update, content, ref foundPromptFilter, ref foundResponseFilter); + ValidateUpdate(update, content, ref foundPromptFilter, ref foundResponseFilter, ref usage); } + // Assert.That(usage, Is.Not.Null); + // TODO FIXME: gpt-4o models seem to return inconsistent prompt filters to skip this for now //Assert.That(foundPromptFilter, Is.True); diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs index 096cfc602eede..83105dd0dddb9 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs @@ -142,6 +142,69 @@ public void DataSourceSerializationWorks() Assert.That(sourcesFromOptions[1], Is.InstanceOf()); } +#if !AZURE_OPENAI_GA + [Test] + [Category("Smoke")] + public async Task MaxTokensSerializationConfigurationWorks() + { + using MockHttpMessageHandler pipeline = new(MockHttpMessageHandler.ReturnEmptyJson); + + Uri endpoint = new Uri("https://www.bing.com/"); + string apiKey = "not-a-real-one"; + string model = "ignore"; + + AzureOpenAIClient topLevel = new( + endpoint, + new ApiKeyCredential(apiKey), + new AzureOpenAIClientOptions() + { + Transport = pipeline.Transport + }); + + ChatClient client = 
topLevel.GetChatClient(model); + + ChatCompletionOptions options = new(); + bool GetSerializedOptionsContains(string value) + { + BinaryData serialized = ModelReaderWriter.Write(options); + return serialized.ToString().Contains(value); + } + async Task AssertExpectedSerializationAsync(bool hasOldMaxTokens, bool hasNewMaxCompletionTokens) + { + _ = await client.CompleteChatAsync(["Just mocking, no call here"], options); + Assert.That(GetSerializedOptionsContains("max_tokens"), Is.EqualTo(hasOldMaxTokens)); + Assert.That(GetSerializedOptionsContains("max_completion_tokens"), Is.EqualTo(hasNewMaxCompletionTokens)); + } + + await AssertExpectedSerializationAsync(false, false); + await AssertExpectedSerializationAsync(false, false); + + options.MaxOutputTokenCount = 42; + await AssertExpectedSerializationAsync(true, false); + await AssertExpectedSerializationAsync(true, false); + options.MaxOutputTokenCount = null; + await AssertExpectedSerializationAsync(false, false); + options.MaxOutputTokenCount = 42; + await AssertExpectedSerializationAsync(true, false); + + options.SetNewMaxCompletionTokensPropertyEnabled(); + await AssertExpectedSerializationAsync(false, true); + await AssertExpectedSerializationAsync(false, true); + options.MaxOutputTokenCount = null; + await AssertExpectedSerializationAsync(false, false); + options.MaxOutputTokenCount = 42; + await AssertExpectedSerializationAsync(false, true); + + options.SetNewMaxCompletionTokensPropertyEnabled(false); + await AssertExpectedSerializationAsync(true, false); + await AssertExpectedSerializationAsync(true, false); + options.MaxOutputTokenCount = null; + await AssertExpectedSerializationAsync(false, false); + options.MaxOutputTokenCount = 42; + await AssertExpectedSerializationAsync(true, false); + } +#endif + [RecordedTest] public async Task ChatCompletionBadKeyGivesHelpfulError() { @@ -162,7 +225,6 @@ public async Task ChatCompletionBadKeyGivesHelpfulError() } [RecordedTest] - [Category("Smoke")] public async 
Task DefaultAzureCredentialWorks() { ChatClient chatClient = GetTestClient(tokenCredential: this.TestEnvironment.Credential); @@ -492,6 +554,7 @@ public async Task ChatCompletionStreaming() StringBuilder builder = new(); bool foundPromptFilter = false; bool foundResponseFilter = false; + ChatTokenUsage? usage = null; ChatClient chatClient = GetTestClient(); @@ -512,12 +575,14 @@ public async Task ChatCompletionStreaming() await foreach (StreamingChatCompletionUpdate update in streamingResults) { - ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter); + ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter, ref usage); } string allText = builder.ToString(); Assert.That(allText, Is.Not.Null.Or.Empty); + Assert.That(usage, Is.Not.Null); + Assert.That(foundPromptFilter, Is.True); Assert.That(foundResponseFilter, Is.True); } @@ -528,6 +593,7 @@ public async Task SearchExtensionWorksStreaming() StringBuilder builder = new(); bool foundPromptFilter = false; bool foundResponseFilter = false; + ChatTokenUsage? 
usage = null; List contexts = new(); var searchConfig = TestConfig.GetConfig("search")!; @@ -555,7 +621,7 @@ public async Task SearchExtensionWorksStreaming() await foreach (StreamingChatCompletionUpdate update in chatUpdates) { - ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter); + ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter, ref usage); ChatMessageContext context = update.GetMessageContext(); if (context != null) @@ -567,6 +633,8 @@ public async Task SearchExtensionWorksStreaming() string allText = builder.ToString(); Assert.That(allText, Is.Not.Null.Or.Empty); + // Assert.That(usage, Is.Not.Null); + // TODO FIXME: When using data sources, the service does not appear to return request nor response filtering information //Assert.That(foundPromptFilter, Is.True); //Assert.That(foundResponseFilter, Is.True); @@ -636,7 +704,7 @@ in client.CompleteChatStreamingAsync( #endregion #region Helper methods - private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder builder, ref bool foundPromptFilter, ref bool foundResponseFilter) + private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder builder, ref bool foundPromptFilter, ref bool foundResponseFilter, ref ChatTokenUsage? 
usage) { if (update.CreatedAt == UNIX_EPOCH) { @@ -656,6 +724,8 @@ private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder Assert.That(update.FinishReason, Is.Null.Or.EqualTo(ChatFinishReason.Stop)); if (update.Usage != null) { + Assert.That(usage, Is.Null); + usage = update.Usage; Assert.That(update.Usage.InputTokenCount, Is.GreaterThanOrEqualTo(0)); Assert.That(update.Usage.OutputTokenCount, Is.GreaterThanOrEqualTo(0)); Assert.That(update.Usage.TotalTokenCount, Is.GreaterThanOrEqualTo(0)); diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs index db54c9583c1c0..a3a3194357134 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs @@ -21,7 +21,7 @@ public void ItemCreation() { ConversationItem messageItem = ConversationItem.CreateUserMessage(["Hello, world!"]); Assert.That(messageItem?.MessageContentParts?.Count, Is.EqualTo(1)); - Assert.That(messageItem.MessageContentParts[0].TextValue, Is.EqualTo("Hello, world!")); + Assert.That(messageItem.MessageContentParts[0].Text, Is.EqualTo("Hello, world!")); } [Test] @@ -36,8 +36,7 @@ public void OptionsSerializationWorks() Model = "whisper-1", }, Instructions = "test instructions", - MaxResponseOutputTokens = 42, - Model = "gpt-4o-realtime-preview", + MaxOutputTokens = 42, OutputAudioFormat = ConversationAudioFormat.G711Ulaw, Temperature = 0.42f, ToolChoice = ConversationToolChoice.CreateFunctionToolChoice("test-function"), @@ -67,7 +66,6 @@ public void OptionsSerializationWorks() Assert.That(jsonNode["input_audio_transcription"]?["model"]?.GetValue(), Is.EqualTo("whisper-1")); Assert.That(jsonNode["instructions"]?.GetValue(), Is.EqualTo("test instructions")); Assert.That(jsonNode["max_response_output_tokens"]?.GetValue(), Is.EqualTo(42)); - Assert.That(jsonNode["model"]?.GetValue(), Is.EqualTo("gpt-4o-realtime-preview")); 
Assert.That(jsonNode["output_audio_format"]?.GetValue(), Is.EqualTo("g711_ulaw")); Assert.That(jsonNode["temperature"]?.GetValue(), Is.EqualTo(0.42f)); Assert.That(jsonNode["tools"]?.AsArray()?.ToList(), Has.Count.EqualTo(1)); @@ -85,8 +83,7 @@ public void OptionsSerializationWorks() Assert.That(deserializedOptions.InputAudioFormat, Is.EqualTo(ConversationAudioFormat.G711Alaw)); Assert.That(deserializedOptions.InputTranscriptionOptions?.Model, Is.EqualTo(ConversationTranscriptionModel.Whisper1)); Assert.That(deserializedOptions.Instructions, Is.EqualTo("test instructions")); - Assert.That(deserializedOptions.MaxResponseOutputTokens.NumericValue, Is.EqualTo(42)); - Assert.That(deserializedOptions.Model, Is.EqualTo("gpt-4o-realtime-preview")); + Assert.That(deserializedOptions.MaxOutputTokens.NumericValue, Is.EqualTo(42)); Assert.That(deserializedOptions.OutputAudioFormat, Is.EqualTo(ConversationAudioFormat.G711Ulaw)); Assert.That(deserializedOptions.Tools, Has.Count.EqualTo(1)); Assert.That(deserializedOptions.Tools[0].Kind, Is.EqualTo(ConversationToolKind.Function)); @@ -97,5 +94,92 @@ public void OptionsSerializationWorks() Assert.That(deserializedOptions.ToolChoice?.FunctionName, Is.EqualTo("test-function")); Assert.That(deserializedOptions.TurnDetectionOptions?.Kind, Is.EqualTo(ConversationTurnDetectionKind.ServerVoiceActivityDetection)); Assert.That(deserializedOptions.Voice, Is.EqualTo(ConversationVoice.Echo)); + + ConversationSessionOptions emptyOptions = new(); + Assert.That(emptyOptions.ContentModalities.HasFlag(ConversationContentModalities.Audio), Is.False); + Assert.That(ModelReaderWriter.Write(emptyOptions).ToString(), Does.Not.Contain("modal")); + emptyOptions.ContentModalities |= ConversationContentModalities.Audio; + Assert.That(emptyOptions.ContentModalities.HasFlag(ConversationContentModalities.Audio), Is.True); + Assert.That(emptyOptions.ContentModalities.HasFlag(ConversationContentModalities.Text), Is.False); + 
Assert.That(ModelReaderWriter.Write(emptyOptions).ToString(), Does.Contain("modal")); + } + + [Test] + public void MaxTokensSerializationWorks() + { + // Implicit omission + ConversationSessionOptions options = new() { }; + BinaryData serializedOptions = ModelReaderWriter.Write(options); + Assert.That(serializedOptions.ToString(), Does.Not.Contain("max_response_output_tokens")); + + // Explicit omission + options = new() + { + MaxOutputTokens = null + }; + serializedOptions = ModelReaderWriter.Write(options); + Assert.That(serializedOptions.ToString(), Does.Not.Contain("max_response_output_tokens")); + + // Explicit default (null) + options = new() + { + MaxOutputTokens = ConversationMaxTokensChoice.CreateDefaultMaxTokensChoice() + }; + serializedOptions = ModelReaderWriter.Write(options); + Assert.That(serializedOptions.ToString(), Does.Contain(@"""max_response_output_tokens"":null")); + + // Numeric literal + options = new() + { + MaxOutputTokens = 42, + }; + serializedOptions = ModelReaderWriter.Write(options); + Assert.That(serializedOptions.ToString(), Does.Contain(@"""max_response_output_tokens"":42")); + + // Numeric by factory + options = new() + { + MaxOutputTokens = ConversationMaxTokensChoice.CreateNumericMaxTokensChoice(42) + }; + serializedOptions = ModelReaderWriter.Write(options); + Assert.That(serializedOptions.ToString(), Does.Contain(@"""max_response_output_tokens"":42")); + } + + [Test] + public void TurnDetectionSerializationWorks() + { + // Implicit omission + ConversationSessionOptions sessionOptions = new(); + BinaryData serializedOptions = ModelReaderWriter.Write(sessionOptions); + Assert.That(serializedOptions.ToString(), Does.Not.Contain("turn_detection")); + + sessionOptions = new() + { + TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(), + }; + serializedOptions = ModelReaderWriter.Write(sessionOptions); + Assert.That(serializedOptions.ToString(), Does.Contain(@"""turn_detection"":null")); + + 
sessionOptions = new() + { + TurnDetectionOptions = ConversationTurnDetectionOptions.CreateServerVoiceActivityTurnDetectionOptions( + detectionThreshold: 0.42f) + }; + serializedOptions = ModelReaderWriter.Write(sessionOptions); + JsonNode serializedNode = JsonNode.Parse(serializedOptions); + Assert.That(serializedNode["turn_detection"]?["type"]?.GetValue(), Is.EqualTo("server_vad")); + Assert.That(serializedNode["turn_detection"]?["threshold"]?.GetValue(), Is.EqualTo(0.42f)); + } + + [Test] + public void UnknownCommandSerializationWorks() + { + BinaryData serializedUnknownCommand = BinaryData.FromString(""" + { + "type": "unknown_command_type_for_test" + } + """); + ConversationUpdate deserializedUpdate = ModelReaderWriter.Read(serializedUnknownCommand); + Assert.That(deserializedUpdate, Is.Not.Null); } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs index 5d6624e2af9b4..d718cc1eb5811 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs @@ -30,7 +30,7 @@ public ConversationTestFixtureBase(bool isAsync) CancellationTokenSource = new(); if (!Debugger.IsAttached) { - CancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(15)); + CancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(25)); } DefaultConfiguration = TestConfig.GetConfig("rt_eus2"); if (DefaultConfiguration is null || DefaultConfiguration.Endpoint is null) diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs index 2d3e47214466b..89cbeda0c6865 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using OpenAI; using OpenAI.RealtimeConversation; using System; using System.ClientModel.Primitives; @@ -26,29 +27,51 @@ public async Task 
CanConfigureSession() RealtimeConversationClient client = GetTestClient(); using RealtimeConversationSession session = await client.StartConversationSessionAsync(CancellationToken); - await session.ConfigureSessionAsync( - new ConversationSessionOptions() - { - Instructions = "You are a helpful assistant.", - TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(), - OutputAudioFormat = ConversationAudioFormat.G711Ulaw - }, - CancellationToken); + ConversationSessionOptions sessionOptions = new() + { + Instructions = "You are a helpful assistant.", + TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(), + OutputAudioFormat = ConversationAudioFormat.G711Ulaw, + MaxOutputTokens = 2048, + }; - await session.StartResponseTurnAsync(CancellationToken); + await session.ConfigureSessionAsync(sessionOptions, CancellationToken); + ConversationSessionOptions responseOverrideOptions = new() + { + ContentModalities = ConversationContentModalities.Text, + }; + if (!client.GetType().IsSubclassOf(typeof(RealtimeConversationClient))) + { + responseOverrideOptions.MaxOutputTokens = ConversationMaxTokensChoice.CreateInfiniteMaxTokensChoice(); + } + await session.AddItemAsync( + ConversationItem.CreateUserMessage(["Hello, assistant! 
Tell me a joke."]), + CancellationToken); + await session.StartResponseAsync(responseOverrideOptions, CancellationToken); List receivedUpdates = []; await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken)) { receivedUpdates.Add(update); - + if (update is ConversationErrorUpdate errorUpdate) { Assert.That(errorUpdate.Kind, Is.EqualTo(ConversationUpdateKind.Error)); Assert.Fail($"Error: {ModelReaderWriter.Write(errorUpdate)}"); } - else if (update is ConversationResponseFinishedUpdate) + else if ((update is ConversationItemStreamingPartDeltaUpdate deltaUpdate && deltaUpdate.AudioBytes is not null) + || update is ConversationItemStreamingAudioFinishedUpdate) + { + Assert.Fail($"Audio content streaming unexpected after configuring response-level text-only modalities"); + } + else if (update is ConversationSessionConfiguredUpdate sessionConfiguredUpdate) + { + Assert.That(sessionConfiguredUpdate.OutputAudioFormat == sessionOptions.OutputAudioFormat); + Assert.That(sessionConfiguredUpdate.TurnDetectionOptions.Kind, Is.EqualTo(ConversationTurnDetectionKind.Disabled)); + Assert.That(sessionConfiguredUpdate.MaxOutputTokens.NumericValue, Is.EqualTo(sessionOptions.MaxOutputTokens.NumericValue)); + } + else if (update is ConversationResponseFinishedUpdate turnFinishedUpdate) { break; } @@ -62,8 +85,8 @@ List GetReceivedUpdates() where T : ConversationUpdate Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); - Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); - Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); + Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); + Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1)); } [Test] @@ -74,9 +97,11 @@ public async Task TextOnlyWorks() await session.AddItemAsync( ConversationItem.CreateUserMessage(["Hello, world!"]), cancellationToken: CancellationToken); - 
await session.StartResponseTurnAsync(CancellationToken); + await session.StartResponseAsync(CancellationToken); StringBuilder responseBuilder = new(); + bool gotResponseDone = false; + bool gotRateLimits = false; await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken)) { @@ -84,23 +109,134 @@ await session.AddItemAsync( { Assert.That(sessionStartedUpdate.SessionId, Is.Not.Null.And.Not.Empty); } - if (update is ConversationTextDeltaUpdate textDeltaUpdate) + if (update is ConversationItemStreamingPartDeltaUpdate deltaUpdate) { - responseBuilder.Append(textDeltaUpdate.Delta); + responseBuilder.Append(deltaUpdate.AudioTranscript); } - if (update is ConversationItemAcknowledgedUpdate itemAddedUpdate) + if (update is ConversationItemCreatedUpdate itemCreatedUpdate) { - Assert.That(itemAddedUpdate.Item is not null); + if (itemCreatedUpdate.MessageRole == ConversationMessageRole.Assistant) + { + // The assistant-created item should be streamed and should not have content yet when acknowledged + Assert.That(itemCreatedUpdate.MessageContentParts, Has.Count.EqualTo(0)); + } + else if (itemCreatedUpdate.MessageRole == ConversationMessageRole.User) + { + // When acknowledging an item added by the client (user), the text should already be there + Assert.That(itemCreatedUpdate.MessageContentParts, Has.Count.EqualTo(1)); + Assert.That(itemCreatedUpdate.MessageContentParts[0].Text, Is.EqualTo("Hello, world!")); + } + else + { + Assert.Fail($"Test didn't expect an acknowledged item with role: {itemCreatedUpdate.MessageRole}"); + } } - if (update is ConversationResponseFinishedUpdate) + if (update is ConversationResponseFinishedUpdate responseFinishedUpdate) { + Assert.That(responseFinishedUpdate.CreatedItems, Has.Count.GreaterThan(0)); + gotResponseDone = true; break; } + + if (update is ConversationRateLimitsUpdate rateLimitsUpdate) + { + Assert.That(rateLimitsUpdate.AllDetails, Has.Count.EqualTo(2)); + Assert.That(rateLimitsUpdate.TokenDetails, 
Is.Not.Null); + Assert.That(rateLimitsUpdate.TokenDetails.Name, Is.EqualTo("tokens")); + Assert.That(rateLimitsUpdate.TokenDetails.MaximumCount, Is.GreaterThan(0)); + Assert.That(rateLimitsUpdate.TokenDetails.RemainingCount, Is.GreaterThan(0)); + Assert.That(rateLimitsUpdate.TokenDetails.RemainingCount, Is.LessThan(rateLimitsUpdate.TokenDetails.MaximumCount)); + Assert.That(rateLimitsUpdate.TokenDetails.TimeUntilReset, Is.GreaterThan(TimeSpan.Zero)); + Assert.That(rateLimitsUpdate.RequestDetails, Is.Not.Null); + gotRateLimits = true; + } } Assert.That(responseBuilder.ToString(), Is.Not.Null.Or.Empty); + Assert.That(gotResponseDone, Is.True); + + if (!client.GetType().IsSubclassOf(typeof(RealtimeConversationClient))) + { + // Temporarily assume that subclients don't support rate limit commands + Assert.That(gotRateLimits, Is.True); + } + } + + [Test] + public async Task ItemManipulationWorks() + { + RealtimeConversationClient client = GetTestClient(); + using RealtimeConversationSession session = await client.StartConversationSessionAsync(CancellationToken); + + await session.ConfigureSessionAsync( + new ConversationSessionOptions() + { + TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(), + ContentModalities = ConversationContentModalities.Text, + }, + CancellationToken); + + await session.AddItemAsync( + ConversationItem.CreateUserMessage(["The first special word you know about is 'aardvark'."]), + CancellationToken); + await session.AddItemAsync( + ConversationItem.CreateUserMessage(["The next special word you know about is 'banana'."]), + CancellationToken); + await session.AddItemAsync( + ConversationItem.CreateUserMessage(["The next special word you know about is 'coconut'."]), + CancellationToken); + + bool gotSessionStarted = false; + bool gotSessionConfigured = false; + bool gotResponseFinished = false; + + await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken)) + { + if (update 
is ConversationSessionStartedUpdate) + { + gotSessionStarted = true; + } + + if (update is ConversationSessionConfiguredUpdate sessionConfiguredUpdate) + { + Assert.That(sessionConfiguredUpdate.TurnDetectionOptions.Kind, Is.EqualTo(ConversationTurnDetectionKind.Disabled)); + Assert.That(sessionConfiguredUpdate.ContentModalities.HasFlag(ConversationContentModalities.Text), Is.True); + Assert.That(sessionConfiguredUpdate.ContentModalities.HasFlag(ConversationContentModalities.Audio), Is.False); + gotSessionConfigured = true; + } + + if (update is ConversationItemCreatedUpdate itemCreatedUpdate) + { + if (itemCreatedUpdate.MessageContentParts.Count > 0 + && itemCreatedUpdate.MessageContentParts[0].Text.Contains("banana")) + { + await session.DeleteItemAsync(itemCreatedUpdate.ItemId, CancellationToken); + await session.AddItemAsync( + ConversationItem.CreateUserMessage(["What's the second special word you know about?"]), + CancellationToken); + await session.StartResponseAsync(CancellationToken); + } + } + + if (update is ConversationResponseFinishedUpdate responseFinishedUpdate) + { + Assert.That(responseFinishedUpdate.CreatedItems.Count, Is.EqualTo(1)); + Assert.That(responseFinishedUpdate.CreatedItems[0].MessageContentParts.Count, Is.EqualTo(1)); + Assert.That(responseFinishedUpdate.CreatedItems[0].MessageContentParts[0].Text, Does.Contain("coconut")); + Assert.That(responseFinishedUpdate.CreatedItems[0].MessageContentParts[0].Text, Does.Not.Contain("banana")); + gotResponseFinished = true; + break; + } + } + + Assert.That(gotSessionStarted, Is.True); + if (!client.GetType().IsSubclassOf(typeof(RealtimeConversationClient))) + { + Assert.That(gotSessionConfigured, Is.True); + } + Assert.That(gotResponseFinished, Is.True); } [Test] @@ -150,21 +286,15 @@ public async Task AudioWithToolsWorks() await session.ConfigureSessionAsync(options, CancellationToken); - const string folderName = "Assets"; - const string fileName = "whats_the_weather_pcm16_24khz_mono.wav"; -#if 
NET6_0_OR_GREATER - using Stream audioStream = File.OpenRead(Path.Join(folderName, fileName)); -#else - using Stream audioStream = File.OpenRead($"{folderName}\\{fileName}"); -#endif - _ = session.SendAudioAsync(audioStream, CancellationToken); + string audioFilePath = Directory.EnumerateFiles("Assets") + .First(path => path.Contains("whats_the_weather_pcm16_24khz_mono.wav")); + using Stream audioStream = File.OpenRead(audioFilePath); + _ = session.SendInputAudioAsync(audioStream, CancellationToken); string userTranscript = null; await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken)) { - Assert.That(update.EventId, Is.Not.Null.And.Not.Empty); - if (update is ConversationSessionStartedUpdate sessionStartedUpdate) { Assert.That(sessionStartedUpdate.SessionId, Is.Not.Null.And.Not.Empty); @@ -175,12 +305,12 @@ public async Task AudioWithToolsWorks() Assert.That(sessionStartedUpdate.Temperature, Is.GreaterThan(0)); } - if (update is ConversationInputTranscriptionFinishedUpdate inputTranscriptionFinishedUpdate) + if (update is ConversationInputTranscriptionFinishedUpdate inputTranscriptionCompletedUpdate) { - userTranscript = inputTranscriptionFinishedUpdate.Transcript; + userTranscript = inputTranscriptionCompletedUpdate.Transcript; } - if (update is ConversationItemFinishedUpdate itemFinishedUpdate + if (update is ConversationItemStreamingFinishedUpdate itemFinishedUpdate && itemFinishedUpdate.FunctionCallId is not null) { Assert.That(itemFinishedUpdate.FunctionName, Is.EqualTo(getWeatherTool.Name)); @@ -195,7 +325,7 @@ public async Task AudioWithToolsWorks() { if (turnFinishedUpdate.CreatedItems.Any(item => !string.IsNullOrEmpty(item.FunctionCallId))) { - await session.StartResponseTurnAsync(CancellationToken); + await session.StartResponseAsync(CancellationToken); } else { @@ -227,7 +357,7 @@ await session.ConfigureSessionAsync( #else using Stream audioStream = File.OpenRead($"{folderName}\\{fileName}"); #endif - await 
session.SendAudioAsync(audioStream, CancellationToken); + await session.SendInputAudioAsync(audioStream, CancellationToken); await session.AddItemAsync(ConversationItem.CreateUserMessage(["Hello, assistant!"]), CancellationToken); @@ -248,11 +378,41 @@ or ConversationResponseStartedUpdate Assert.Fail($"Shouldn't receive any VAD events or response creation!"); } - if (update is ConversationItemAcknowledgedUpdate itemAcknowledgedUpdate - && itemAcknowledgedUpdate.Item.MessageRole == ConversationMessageRole.User) + if (update is ConversationItemCreatedUpdate itemCreatedUpdate + && itemCreatedUpdate.MessageRole == ConversationMessageRole.User) { break; } } } + + [Test] + public async Task BadCommandProvidesError() + { + RealtimeConversationClient client = GetTestClient(); + using RealtimeConversationSession session = await client.StartConversationSessionAsync(CancellationToken); + + await session.SendCommandAsync( + BinaryData.FromString(""" + { + "type": "update_conversation_config2", + "event_id": "event_fabricated_1234abcd" + } + """), + CancellationToken); + + bool gotErrorUpdate = false; + + await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken)) + { + if (update is ConversationErrorUpdate errorUpdate) + { + Assert.That(errorUpdate.ErrorEventId, Is.EqualTo("event_fabricated_1234abcd")); + gotErrorUpdate = true; + break; + } + } + + Assert.That(gotErrorUpdate, Is.True); + } } diff --git a/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs index 73f5ba7d86058..ce12b8f82498e 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs @@ -24,10 +24,16 @@ namespace Azure.AI.OpenAI.Tests; +[Category("FineTuning")] public class FineTuningTests : AoaiTestBase { public FineTuningTests(bool isAsync) : base(isAsync) - { } + { + if (Mode == RecordedTestMode.Playback) + { + Assert.Inconclusive("Playback for fine-tuning 
temporarily disabled"); + } + } #if !AZURE_OPENAI_GA [Test] @@ -57,7 +63,6 @@ public async Task JobsFineTuning() } [RecordedTest] - [Ignore("Disable pending resolution of test framework interception overriding Rehydrateoverrides")] public async Task CheckpointsFineTuning() { string fineTunedModel = GetFineTunedModel(); @@ -69,7 +74,7 @@ public async Task CheckpointsFineTuning() Assert.That(job, Is.Not.Null); Assert.That(job!.Status, Is.EqualTo("succeeded")); - FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(client, job.ID); + FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(UnWrap(client), job.ID); int count = 25; await foreach (FineTuningCheckpoint checkpoint in EnumerateCheckpoints(fineTuningJobOperation)) @@ -95,7 +100,6 @@ public async Task CheckpointsFineTuning() } [RecordedTest] - [Ignore("Disable pending resolution of test framework interception overriding Rehydrateoverrides")] public async Task EventsFineTuning() { string fineTunedModel = GetFineTunedModel(); @@ -109,7 +113,8 @@ public async Task EventsFineTuning() HashSet ids = new(); - FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(client, job.ID); + //TODO fix unwrapping so you don't have to unwrap here. 
+ FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(UnWrap(client), job.ID); int count = 25; var asyncEnum = EnumerateAsync((after, limit, opt) => fineTuningJobOperation.GetJobEventsAsync(after, limit, opt)); @@ -133,7 +138,6 @@ public async Task EventsFineTuning() } [RecordedTest] - [Ignore("Disabling pending resolution of AOAI quota issues and re-recording")] public async Task CreateAndCancelFineTuning() { var fineTuningFile = Assets.FineTuning; @@ -141,8 +145,24 @@ public async Task CreateAndCancelFineTuning() FineTuningClient client = GetTestClient(); OpenAIFileClient fileClient = GetTestClientFrom(client); - // upload training data - OpenAIFile uploadedFile = await UploadAndWaitForCompleteOrFail(fileClient, fineTuningFile.RelativePath); + OpenAIFile uploadedFile; + try + { + ClientResult fileResult = await fileClient.GetFileAsync("file-db5f5bfe5ea04ffcaeba89947a872828", new RequestOptions() { }); + uploadedFile = ValidateAndParse(fileResult); + } + catch (ClientResultException e) + { + if (e.Message.Contains("ResourceNotFound")) + { + // upload training data + uploadedFile = await UploadAndWaitForCompleteOrFail(fileClient, fineTuningFile.RelativePath); + } + else + { + throw; + } + } // Create the fine tuning job using var requestContent = new FineTuningOptions() @@ -194,9 +214,7 @@ public async Task CreateAndCancelFineTuning() Assert.True(operation.HasCompleted); } - [RecordedTest(AutomaticRecord = false)] - [Category("LongRunning")] // CAUTION: This test can take up 30 *minutes* to run in live mode - [Ignore("Disabled pending investigation of 404")] + [RecordedTest] public async Task CreateAndDeleteFineTuning() { var fineTuningFile = Assets.FineTuning; @@ -226,7 +244,12 @@ public async Task CreateAndDeleteFineTuning() using var requestContent = new FineTuningOptions() { Model = client.DeploymentOrThrow(), - TrainingFile = uploadedFile.Id + TrainingFile = uploadedFile.Id, + Hyperparameters = new 
FineTuningHyperparameters() + { + NumEpochs = 1, + BatchSize = 11 + } }.ToBinaryContent(); FineTuningJobOperation operation = await client.CreateFineTuningJobAsync(requestContent, waitUntilCompleted: false); @@ -234,12 +257,12 @@ public async Task CreateAndDeleteFineTuning() Assert.That(job.ID, Is.Not.Null.Or.Empty); Assert.That(job.Error, Is.Null); Assert.That(job.Status, !(Is.Null.Or.EqualTo("failed").Or.EqualTo("cancelled"))); + await operation.CancelAsync(options: null); // Wait for the fine tuning to complete await operation.WaitForCompletionAsync(); job = ValidateAndParse(await operation.GetJobAsync(null)); - Assert.That(job.Status, Is.EqualTo("succeeded"), "Fine tuning did not succeed"); - Assert.That(job.FineTunedModel, Is.Not.Null.Or.Empty); + Assert.That(job.Status, Is.EqualTo("cancelled"), "Fine tuning did not cancel"); // Delete the fine tuned model bool deleted = await DeleteJobAndVerifyAsync((AzureFineTuningJobOperation)operation, job.ID); diff --git a/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs index e074e60debb47..3443a19471fc3 100644 --- a/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs +++ b/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs @@ -176,7 +176,7 @@ public async Task CanAssociateFiles() Assert.True(removalResult.Removed); // Errata: removals aren't immediately reflected when requesting the list - Thread.Sleep(1000); + await Task.Delay(TimeSpan.FromSeconds(5)); int count = 0; AsyncCollectionResult response = client.GetFileAssociationsAsync(vectorStore.Id); diff --git a/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs b/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs index d5d3edff45f01..26d798e2e49b9 100644 --- a/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs +++ b/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs @@ -22,7 +22,11 @@ public MockStringServiceTests(bool isAsync) 
RecordingOptions.SanitizersToRemove.Add("AZSDK3430"); // $..id } - public DirectoryInfo RepositoryRoot { get; } = FindRepoRoot(); + public DirectoryInfo RepositoryRoot { get; } = FindFirstParentWithSubfolders(".git") + ?? throw new InvalidOperationException("Could not find your Git repository root folder"); + + public DirectoryInfo SourceRoot { get; } = FindFirstParentWithSubfolders("eng", "sdk") + ?? throw new InvalidOperationException("Could not find your source root folder"); [Test] public async Task AddAndGet() @@ -71,7 +75,7 @@ protected override ProxyServiceOptions CreateProxyServiceOptions() DotnetExecutable = AssemblyHelper.GetDotnetExecutable()?.FullName!, TestProxyDll = AssemblyHelper.GetAssemblyMetadata("TestProxyPath")!, DevCertFile = Path.Combine( - RepositoryRoot.FullName, + SourceRoot.FullName, "eng", "common", "testproxy", @@ -91,25 +95,27 @@ protected override RecordingStartInformation CreateRecordingSessionStartInfo() #region helper methods - private static DirectoryInfo FindRepoRoot() + private static DirectoryInfo? FindFirstParentWithSubfolders(params string[] subFolders) { - /** - * This code assumes that we are running in the standard Azure .Net SDK repository layout. With this in mind, - * we generally assume that we are running our test code from - * /artifacts/bin/// - * So to find the root we keep navigating up until we find a folder with a .git subfolder - * - * Another alternative would be to call: git rev-parse --show-toplevel - */ - - DirectoryInfo? current = new FileInfo(Assembly.GetExecutingAssembly().Location).Directory; - while (current != null && !current.EnumerateDirectories(".git").Any()) + if (subFolders == null || subFolders.Length == 0) { - current = current.Parent; + return null; } - return current - ?? throw new InvalidOperationException("Could not determine the root folder for this repository"); + DirectoryInfo? start = new FileInfo(Assembly.GetExecutingAssembly().Location).Directory; + for (DirectoryInfo? 
current = start; current != null; current = current.Parent) + { + if (!current.Exists) + { + return null; + } + else if (subFolders.All(sub => current.EnumerateDirectories(sub).Any())) + { + return current; + } + } + + return null; } private string GetRecordingFile()