diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props
index df417b1aaa291..dd5727d983b61 100644
--- a/eng/Packages.Data.props
+++ b/eng/Packages.Data.props
@@ -184,7 +184,7 @@
-
+
diff --git a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md
index c8b78cea80db8..803541cd70e03 100644
--- a/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md
+++ b/sdk/openai/Azure.AI.OpenAI/CHANGELOG.md
@@ -1,5 +1,35 @@
# Release History
+## 2.1.0-beta.2 (2024-11-04)
+
+This update brings compatibility with the Azure OpenAI `2024-10-01-preview` service API version as well as the `2.1.0-beta.2` release of the `OpenAI` library.
+
+### Breaking Changes
+
+- `[Experimental]` `ChatCitation` and `ChatRetrievedDocument` have each replaced the `Uri` property of type `System.Uri` with a `string` property named `Url`. This aligns with the REST specification and accounts for the wire value of `url` not always providing a valid RFC 3986 identifier [[azure-sdk-for-net \#46793](https://github.com/Azure/azure-sdk-for-net/issues/46793)]
+
+### Features Added
+
+- The included update via `2024-09-01-preview` brings AOAI support for streaming token usage in chat completions; `Usage` is now automatically populated in `StreamingChatCompletionUpdate` instances.
+ - Note 1: this feature is not yet compatible when using On Your Data features (after invoking the `.AddDataSource()` extension method on `ChatCompletionOptions`)
+ - Note 2: this feature is not yet compatible when using image input (a `ChatMessageContentPart` of `Kind` `Image`)
+- `2024-10-01-preview` further adds support for ungrounded content detection in chat completion content filter results via the `UngroundedMaterial` property on `ResponseContentFilterResult`, as retrieved from a chat completion via the `GetResponseContentFilterResult()` extension method.
+
+Via `OpenAI 2.1.0-beta.2`:
+
+- Made improvements to the experimental Realtime API. Please note this feature area is currently under rapid development and not all changes may be reflected here.
+ - Several types have been renamed for consistency and clarity.
+ - `ConversationRateLimitsUpdate` (previously `ConversationRateLimitsUpdatedUpdate`) now includes named `RequestDetails` and `TokenDetails` properties, mapping to the corresponding named items in the underlying `rate_limits` command payload.
+
+### Bugs Fixed
+
+- Addressed an HTTP 401 issue that caused certain connection retry attempts, such as those triggered for HTTP 429 rate limiting errors, to sometimes generate a malformed request with multiple `Authorization` headers that would then be rejected. [#46401](https://github.com/Azure/azure-sdk-for-net/pull/46401)
+- Addressed an issue that caused `ChatCitation` and `ChatRetrievedDocument` to sometimes throw on deserialization, specifically when a returned value in the `url` JSON field was not populated with an RFC 3986 compliant identifier for `System.Uri` [[azure-sdk-for-net \#46793](https://github.com/Azure/azure-sdk-for-net/issues/46793)]
+
+Via `OpenAI 2.1.0-beta.2`:
+
+- Fixed serialization and deserialization of ConversationToolChoice literal values (such as "required").
+
## 2.1.0-beta.1 (2024-10-01)
Relative to the prior GA release, this update restores preview surfaces, retargeting to the latest `2024-08-01-preview` service `api-version` label. It also brings early support for the newly-announced `/realtime` capabilities with `gpt-4o-realtime-preview`. You can read more about Azure OpenAI support for `/realtime` in the annoucement post here: https://azure.microsoft.com/blog/announcing-new-products-and-features-for-azure-openai-service-including-gpt-4o-realtime-preview-with-audio-and-speech-capabilities/
@@ -10,7 +40,7 @@ Relative to the prior GA release, this update restores preview surfaces, retarge
- This maps to the new `/realtime` beta endpoint and is thus marked with a new `[Experimental("OPENAI002")]` diagnostic tag.
- This is a very early version of the convenience surface and thus subject to significant change
- Documentation and samples will arrive soon; in the interim, see the scenario test files (in `/tests`) for basic usage
- - You can also find an external sample employing this client, together with Azure OpenAI support, at https://github.com/Azure-Samples/aoai-realtime-audio-sdk/tree/main/dotnet/samples/console
+ - You can also find an external sample employing this client, together with Azure OpenAI support, at https://github.com/Azure-Samples/aoai-realtime-audio-sdk/tree/main/dotnet/samples
## 2.0.0 (2024-09-30)
diff --git a/sdk/openai/Azure.AI.OpenAI/Directory.Build.props b/sdk/openai/Azure.AI.OpenAI/Directory.Build.props
index 8f8f47e265927..51ec75d706d50 100644
--- a/sdk/openai/Azure.AI.OpenAI/Directory.Build.props
+++ b/sdk/openai/Azure.AI.OpenAI/Directory.Build.props
@@ -1,5 +1,6 @@
+ 2.1.0-beta.2
true
$(RequiredTargetFrameworks)
true
@@ -26,14 +27,8 @@
AZURE_OPENAI_GA
-
-
-
- beta.1
-
-
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs
index 0fff6772293d6..54838af00b039 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClientOptions.cs
@@ -56,6 +56,7 @@ public AzureOpenAIClientOptions(ServiceVersion version = LatestVersion)
{
#if !AZURE_OPENAI_GA
ServiceVersion.V2024_08_01_Preview => "2024-08-01-preview",
+ ServiceVersion.V2024_09_01_Preview => "2024-09-01-preview",
ServiceVersion.V2024_10_01_Preview => "2024-10-01-preview",
#endif
ServiceVersion.V2024_06_01 => "2024-06-01",
@@ -70,6 +71,7 @@ public enum ServiceVersion
V2024_06_01 = 0,
#if !AZURE_OPENAI_GA
V2024_08_01_Preview = 1,
+ V2024_09_01_Preview = 2,
V2024_10_01_Preview = 3,
#endif
}
@@ -103,7 +105,7 @@ protected override TimeSpan GetNextDelay(PipelineMessage message, int tryCount)
}
#if !AZURE_OPENAI_GA
- private const ServiceVersion LatestVersion = ServiceVersion.V2024_08_01_Preview;
+ private const ServiceVersion LatestVersion = ServiceVersion.V2024_10_01_Preview;
#else
private const ServiceVersion LatestVersion = ServiceVersion.V2024_06_01;
#endif
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs
index 036b82e222794..7d5c692ab45d0 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatClient.cs
@@ -1,11 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+using Azure.AI.OpenAI.Internal;
using OpenAI.Chat;
using System.ClientModel;
using System.ClientModel.Primitives;
+using System.Data.SqlTypes;
using System.Diagnostics.CodeAnalysis;
+#pragma warning disable AOAI001
#pragma warning disable AZC0112
namespace Azure.AI.OpenAI.Chat;
@@ -63,7 +66,7 @@ public override CollectionResult CompleteChatStre
///
public override AsyncCollectionResult CompleteChatStreamingAsync(IEnumerable messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
- PostfixClearStreamOptions(ref options);
+ PostfixClearStreamOptions(messages, ref options);
PostfixSwapMaxTokens(ref options);
return base.CompleteChatStreamingAsync(messages, options, cancellationToken);
}
@@ -71,25 +74,78 @@ public override AsyncCollectionResult CompleteCha
///
public override CollectionResult CompleteChatStreaming(IEnumerable messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
- PostfixClearStreamOptions(ref options);
+ PostfixClearStreamOptions(messages, ref options);
PostfixSwapMaxTokens(ref options);
return base.CompleteChatStreaming(messages, options, cancellationToken);
}
- private static void PostfixClearStreamOptions(ref ChatCompletionOptions options)
+ /**
+ * As of 2024-09-01-preview, stream_options support for include_usage (which reports token usage while streaming)
+ * is conditionally supported:
+ * - When using On Your Data (non-null data_sources), stream_options is not considered valid
+ * - When using image input (any content part of "image" type), stream_options is not considered valid
+ * - Otherwise, stream_options can be defaulted to enabled per parity surface.
+ */
+ private static void PostfixClearStreamOptions(IEnumerable messages, ref ChatCompletionOptions options)
{
- options ??= new();
- options.StreamOptions = null;
+ if (AdditionalPropertyHelpers
+ .GetAdditionalListProperty(options?.SerializedAdditionalRawData, "data_sources")?.Count > 0
+ || messages?.Any(
+ message => message?.Content?.Any(
+ contentPart => contentPart?.Kind == ChatMessageContentPartKind.Image) == true)
+ == true)
+ {
+ options ??= new();
+ options.StreamOptions = null;
+ }
}
+ /**
+ * As of 2024-09-01-preview, Azure OpenAI conditionally supports the use of the new max_completion_tokens property:
+ * - The o1-mini and o1-preview models accept max_completion_tokens and reject max_tokens
+ * - All other models reject max_completion_tokens and accept max_tokens
+ * To handle this, each request will manipulate serialization overrides:
+ * - If max tokens aren't set, no action is taken
+ * - If serialization of max_tokens has already been blocked (e.g. via the public extension method), no
+ * additional logic is used and new serialization to max_completion_tokens will occur
+ * - Otherwise, serialization of max_completion_tokens is blocked and an override serialization of the
+ * corresponding max_tokens value is established
+ */
private static void PostfixSwapMaxTokens(ref ChatCompletionOptions options)
{
options ??= new();
- if (options.MaxOutputTokenCount is not null)
+ bool valueIsSet = options.MaxOutputTokenCount is not null;
+ bool oldPropertyBlocked = AdditionalPropertyHelpers.GetIsEmptySentinelValue(options.SerializedAdditionalRawData, "max_tokens");
+
+ if (valueIsSet)
+ {
+ if (!oldPropertyBlocked)
+ {
+ options.SerializedAdditionalRawData ??= new ChangeTrackingDictionary();
+ AdditionalPropertyHelpers.SetEmptySentinelValue(options.SerializedAdditionalRawData, "max_completion_tokens");
+ options.SerializedAdditionalRawData["max_tokens"] = BinaryData.FromObjectAsJson(options.MaxOutputTokenCount);
+ }
+ else
+ {
+ // Allow standard serialization to the new property to occur; remove overrides
+ if (options.SerializedAdditionalRawData.ContainsKey("max_completion_tokens"))
+ {
+ options.SerializedAdditionalRawData.Remove("max_completion_tokens");
+ }
+ }
+ }
+ else
{
- options.SerializedAdditionalRawData ??= new Dictionary();
- options.SerializedAdditionalRawData["max_completion_tokens"] = BinaryData.FromObjectAsJson("__EMPTY__");
- options.SerializedAdditionalRawData["max_tokens"] = BinaryData.FromObjectAsJson(options.MaxOutputTokenCount);
+ if (!AdditionalPropertyHelpers.GetIsEmptySentinelValue(options.SerializedAdditionalRawData, "max_tokens")
+ && options.SerializedAdditionalRawData?.ContainsKey("max_tokens") == true)
+ {
+ options.SerializedAdditionalRawData.Remove("max_tokens");
+ }
+ if (!AdditionalPropertyHelpers.GetIsEmptySentinelValue(options.SerializedAdditionalRawData, "max_completion_tokens")
+ && options.SerializedAdditionalRawData?.ContainsKey("max_completion_tokens") == true)
+ {
+ options.SerializedAdditionalRawData.Remove("max_completion_tokens");
+ }
}
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs
index dc88aed8299de..e53541fe1da26 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/AzureChatExtensions.cs
@@ -36,6 +36,25 @@ public static IReadOnlyList GetDataSources(this ChatCompletionOp
"data_sources") as IReadOnlyList;
}
+ [Experimental("AOAI001")]
+ public static void SetNewMaxCompletionTokensPropertyEnabled(this ChatCompletionOptions options, bool newPropertyEnabled = true)
+ {
+ if (newPropertyEnabled)
+ {
+ // Blocking serialization of max_tokens via dictionary acts as a signal to skip pre-serialization fixup
+ AdditionalPropertyHelpers.SetEmptySentinelValue(options.SerializedAdditionalRawData, "max_tokens");
+ }
+ else
+ {
+ // In the absence of a dictionary serialization block to max_tokens, the newer property name will
+ // automatically be blocked and the older property name will be used via dictionary override
+ if (options?.SerializedAdditionalRawData?.ContainsKey("max_tokens") == true)
+ {
+ options?.SerializedAdditionalRawData?.Remove("max_tokens");
+ }
+ }
+ }
+
[Experimental("AOAI001")]
public static RequestContentFilterResult GetRequestContentFilterResult(this ChatCompletion chatCompletion)
{
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs
index e0dc3895e0718..2dca2d7d694c8 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatCitation.cs
@@ -12,9 +12,6 @@ namespace Azure.AI.OpenAI.Chat;
public partial class ChatCitation
{
- /// The location of the citation.
- [CodeGenMember("Url")]
- public Uri Uri { get; }
/// The file path for the citation.
[CodeGenMember("Filepath")]
public string FilePath { get; }
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs
index b0703091af032..74d841304ba79 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Chat/OnYourData/ChatRetrievedDocument.cs
@@ -14,8 +14,4 @@ public partial class ChatRetrievedDocument
/// The file path for the citation.
[CodeGenMember("Filepath")]
public string FilePath { get; }
-
- /// The location of the citation.
- [CodeGenMember("Url")]
- public Uri Uri { get; }
}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs
index bc0c0a10b3c7a..6086bdfdacf68 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/AdditionalPropertyHelpers.cs
@@ -8,6 +8,8 @@ namespace Azure.AI.OpenAI.Internal;
internal static class AdditionalPropertyHelpers
{
+ private static string SARD_EMPTY_SENTINEL = "__EMPTY__";
+
internal static T GetAdditionalProperty(IDictionary additionalProperties, string key)
where T : class, IJsonModel
{
@@ -45,4 +47,17 @@ internal static void SetAdditionalProperty(IDictionary ad
BinaryData binaryValue = BinaryData.FromStream(stream);
additionalProperties[key] = binaryValue;
}
+
+ internal static void SetEmptySentinelValue(IDictionary additionalProperties, string key)
+ {
+ Argument.AssertNotNull(additionalProperties, nameof(additionalProperties));
+ additionalProperties[key] = BinaryData.FromObjectAsJson(SARD_EMPTY_SENTINEL);
+ }
+
+ internal static bool GetIsEmptySentinelValue(IDictionary additionalProperties, string key)
+ {
+ return additionalProperties is not null
+ && additionalProperties.TryGetValue(key, out BinaryData existingValue)
+ && StringComparer.OrdinalIgnoreCase.Equals(existingValue.ToString(), $@"""{SARD_EMPTY_SENTINEL}""");
+ }
}
\ No newline at end of file
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterTextSpanResult.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterTextSpanResult.cs
new file mode 100644
index 0000000000000..d1edc581f2768
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/ContentFilterTextSpanResult.cs
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System.Diagnostics.CodeAnalysis;
+
+namespace Azure.AI.OpenAI;
+
+[Experimental("AOAI001")]
+[CodeGenModel("AzureContentFilterCompletionTextSpanDetectionResult")]
+public partial class ContentFilterTextSpanResult
+{ }
\ No newline at end of file
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs
index a5504cd496951..ff66fd09eb877 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/Common/GeneratorStubs.cs
@@ -8,3 +8,4 @@ namespace Azure.AI.OpenAI;
[Experimental("AOAI001")][CodeGenModel("AzureContentFilterDetectionResult")] public partial class ContentFilterDetectionResult { }
[Experimental("AOAI001")][CodeGenModel("AzureContentFilterResultForChoiceProtectedMaterialCode")] public partial class ContentFilterProtectedMaterialResult { }
[Experimental("AOAI001")][CodeGenModel("AzureContentFilterSeverityResultSeverity")] public readonly partial struct ContentFilterSeverity { }
+[Experimental("AOAI001")][CodeGenModel("AzureContentFilterCompletionTextSpan")] public partial class ContentFilterTextSpan { }
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs
index 8354486cbaed9..029a7f9ac8d96 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationClient.cs
@@ -52,8 +52,6 @@ private static Uri GetEndpoint(Uri endpoint, string deploymentName, string apiVe
_ => uriBuilder.Scheme,
};
- apiVersion = "2024-10-01-preview";
-
bool isLegacyNoDeployment = string.IsNullOrEmpty(deploymentName);
string requiredPathSuffix = isLegacyNoDeployment ? "realtime" : "openai/realtime";
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs
index aa3d67dbb9369..38953624b0ba5 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.Protocol.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
using Azure.Core;
-using OpenAI.RealtimeConversation;
using System.ClientModel.Primitives;
using System.ComponentModel;
using System.Net.WebSockets;
@@ -14,31 +13,74 @@ internal partial class AzureRealtimeConversationSession : RealtimeConversationSe
[EditorBrowsable(EditorBrowsableState.Never)]
protected internal override async Task ConnectAsync(RequestOptions options)
{
- ClientUriBuilder uriBuilder = new();
- uriBuilder.Reset(_endpoint);
+ WebSocket?.Dispose();
- if (_tokenCredential is not null)
+ // Temporary mitigation of static RPM rate limiting behavior: pending protocol-based rate limit integration with
+ // rate_limits.updated, use a static retry strategy upon receiving a 429 error.
+
+ bool shouldRetryOnFailure = false;
+ int rateLimitRetriesUsed = 0;
+ const int maximumRateLimitRetries = 3;
+ TimeSpan timeBetweenRateLimitRetries = TimeSpan.FromSeconds(5);
+
+ do
+ {
+ try
+ {
+ ClientWebSocket webSocket = await CreateAzureWebSocketAsync(options).ConfigureAwait(false);
+ await webSocket.ConnectAsync(_endpoint, options?.CancellationToken ?? default)
+ .ConfigureAwait(false);
+ WebSocket = webSocket;
+ }
+ catch (WebSocketException webSocketException)
+ {
+ shouldRetryOnFailure
+ = webSocketException.Message?.Contains("429") == true && rateLimitRetriesUsed++ < maximumRateLimitRetries;
+ if (shouldRetryOnFailure)
+ {
+ await Task.Delay(timeBetweenRateLimitRetries).ConfigureAwait(false);
+ }
+ else
+ {
+ throw;
+ }
+ }
+ } while (shouldRetryOnFailure);
+ }
+
+ private async Task CreateAzureWebSocketAsync(RequestOptions options)
+ {
+ string clientRequestId = Guid.NewGuid().ToString();
+
+ ClientWebSocket clientWebSocket = new();
+ clientWebSocket.Options.AddSubProtocol("realtime");
+ clientWebSocket.Options.SetRequestHeader("openai-beta", $"realtime=v1");
+ clientWebSocket.Options.SetRequestHeader("x-ms-client-request-id", clientRequestId);
+
+ try
{
- AccessToken token = await _tokenCredential.GetTokenAsync(_tokenRequestContext, options?.CancellationToken ?? default).ConfigureAwait(false);
- _clientWebSocket.Options.SetRequestHeader("Authorization", $"Bearer {token.Token}");
+ clientWebSocket.Options.SetRequestHeader("User-Agent", _userAgent);
}
- else
+ catch (ArgumentException argumentException)
{
- _keyCredential.Deconstruct(out string dangerousCredential);
- _clientWebSocket.Options.SetRequestHeader("api-key", dangerousCredential);
- // uriBuilder.AppendQuery("api-key", dangerousCredential, escape: false);
+ throw new PlatformNotSupportedException(
+ $"{nameof(RealtimeConversationClient)} is not yet supported on older .NET framework targets.",
+ argumentException);
}
- Uri endpoint = uriBuilder.ToUri();
-
- try
+ if (_tokenCredential is not null)
{
- await _clientWebSocket.ConnectAsync(endpoint, options?.CancellationToken ?? default)
+ TokenRequestContext tokenRequestContext = new(_tokenAuthorizationScopes.ToArray(), clientRequestId);
+ AccessToken token = await _tokenCredential.GetTokenAsync(tokenRequestContext, options?.CancellationToken ?? default)
.ConfigureAwait(false);
+ clientWebSocket.Options.SetRequestHeader("Authorization", $"Bearer {token.Token}");
}
- catch (WebSocketException)
+ else
{
- throw;
+ _keyCredential.Deconstruct(out string dangerousCredential);
+ clientWebSocket.Options.SetRequestHeader("api-key", dangerousCredential);
}
+
+ return clientWebSocket;
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs
index 1ee9413dfe952..0fe51b40d754a 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Custom/RealtimeConversation/AzureRealtimeConversationSession.cs
@@ -3,11 +3,10 @@
using System.ClientModel;
using System.ClientModel.Primitives;
-using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Net;
+using System.Net.WebSockets;
using Azure.Core;
-using OpenAI.RealtimeConversation;
namespace Azure.AI.OpenAI.RealtimeConversation;
@@ -18,8 +17,7 @@ internal partial class AzureRealtimeConversationSession : RealtimeConversationSe
private readonly ApiKeyCredential _keyCredential;
private readonly TokenCredential _tokenCredential;
private readonly IEnumerable _tokenAuthorizationScopes;
- private readonly TokenRequestContext _tokenRequestContext;
- private readonly string _clientRequestId;
+ private readonly string _userAgent;
protected internal AzureRealtimeConversationSession(
AzureRealtimeConversationClient parentClient,
@@ -41,33 +39,12 @@ protected internal AzureRealtimeConversationSession(
{
_tokenCredential = credential;
_tokenAuthorizationScopes = tokenAuthorizationScopes;
- _tokenRequestContext = new(_tokenAuthorizationScopes.ToArray(), parentRequestId: _clientRequestId);
}
private AzureRealtimeConversationSession(AzureRealtimeConversationClient parentClient, Uri endpoint, string userAgent)
: base(parentClient, endpoint, credential: new("placeholder"))
{
- _clientRequestId = Guid.NewGuid().ToString();
-
_endpoint = endpoint;
- _clientWebSocket.Options.AddSubProtocol("realtime");
- _clientWebSocket.Options.SetRequestHeader("User-Agent", userAgent);
- _clientWebSocket.Options.SetRequestHeader("x-ms-client-request-id", _clientRequestId);
- }
-
- internal override async Task SendCommandAsync(InternalRealtimeRequestCommand command, CancellationToken cancellationToken = default)
- {
- BinaryData requestData = ModelReaderWriter.Write(command);
-
- // Temporary backcompat quirk
- if (command is InternalRealtimeRequestSessionUpdateCommand sessionUpdateCommand
- && sessionUpdateCommand.Session?.TurnDetectionOptions is InternalRealtimeNoTurnDetection)
- {
- requestData = BinaryData.FromString(requestData.ToString()
- .Replace(@"""turn_detection"":null", @"""turn_detection"":{""type"":""none""}"));
- }
-
- RequestOptions cancellationOptions = cancellationToken.ToRequestOptions();
- await SendCommandAsync(requestData, cancellationOptions).ConfigureAwait(false);
+ _userAgent = userAgent;
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs
index d361f85eecfde..ec636d3186115 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.Serialization.cs
@@ -31,10 +31,10 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOpti
writer.WritePropertyName("title"u8);
writer.WriteStringValue(Title);
}
- if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Uri))
+ if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Url))
{
writer.WritePropertyName("url"u8);
- writer.WriteStringValue(Uri.AbsoluteUri);
+ writer.WriteStringValue(Url);
}
if (SerializedAdditionalRawData?.ContainsKey("filepath") != true && Optional.IsDefined(FilePath))
{
@@ -95,7 +95,7 @@ internal static ChatCitation DeserializeChatCitation(JsonElement element, ModelR
}
string content = default;
string title = default;
- Uri url = default;
+ string url = default;
string filepath = default;
string chunkId = default;
double? rerankScore = default;
@@ -115,11 +115,7 @@ internal static ChatCitation DeserializeChatCitation(JsonElement element, ModelR
}
if (property.NameEquals("url"u8))
{
- if (property.Value.ValueKind == JsonValueKind.Null)
- {
- continue;
- }
- url = new Uri(property.Value.GetString());
+ url = property.Value.GetString();
continue;
}
if (property.NameEquals("filepath"u8))
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs
index 6d49dbb6950e9..74cd970915eab 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatCitation.cs
@@ -54,16 +54,16 @@ internal ChatCitation(string content)
/// Initializes a new instance of .
/// The content of the citation.
/// The title for the citation.
- /// The URL of the citation.
+ /// The URL of the citation.
/// The file path for the citation.
/// The chunk ID for the citation.
/// The rerank score for the retrieval.
/// Keeps track of any properties unknown to the library.
- internal ChatCitation(string content, string title, Uri uri, string filePath, string chunkId, double? rerankScore, IDictionary serializedAdditionalRawData)
+ internal ChatCitation(string content, string title, string url, string filePath, string chunkId, double? rerankScore, IDictionary serializedAdditionalRawData)
{
Content = content;
Title = title;
- Uri = uri;
+ Url = url;
FilePath = filePath;
ChunkId = chunkId;
RerankScore = rerankScore;
@@ -79,6 +79,8 @@ internal ChatCitation()
public string Content { get; }
/// The title for the citation.
public string Title { get; }
+ /// The URL of the citation.
+ public string Url { get; }
/// The chunk ID for the citation.
public string ChunkId { get; }
/// The rerank score for the retrieval.
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs
index 2ab576c95b612..5f62fdce58d67 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.Serialization.cs
@@ -31,10 +31,10 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderW
writer.WritePropertyName("title"u8);
writer.WriteStringValue(Title);
}
- if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Uri))
+ if (SerializedAdditionalRawData?.ContainsKey("url") != true && Optional.IsDefined(Url))
{
writer.WritePropertyName("url"u8);
- writer.WriteStringValue(Uri.AbsoluteUri);
+ writer.WriteStringValue(Url);
}
if (SerializedAdditionalRawData?.ContainsKey("filepath") != true && Optional.IsDefined(FilePath))
{
@@ -120,7 +120,7 @@ internal static ChatRetrievedDocument DeserializeChatRetrievedDocument(JsonEleme
}
string content = default;
string title = default;
- Uri url = default;
+ string url = default;
string filepath = default;
string chunkId = default;
double? rerankScore = default;
@@ -144,11 +144,7 @@ internal static ChatRetrievedDocument DeserializeChatRetrievedDocument(JsonEleme
}
if (property.NameEquals("url"u8))
{
- if (property.Value.ValueKind == JsonValueKind.Null)
- {
- continue;
- }
- url = new Uri(property.Value.GetString());
+ url = property.Value.GetString();
continue;
}
if (property.NameEquals("filepath"u8))
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs
index 96e966ba79e03..145d0be9adf6c 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ChatRetrievedDocument.cs
@@ -60,7 +60,7 @@ internal ChatRetrievedDocument(string content, IEnumerable searchQueries
/// Initializes a new instance of .
/// The content of the citation.
/// The title for the citation.
- /// The URL of the citation.
+ /// The URL of the citation.
/// The file path for the citation.
/// The chunk ID for the citation.
/// The rerank score for the retrieval.
@@ -69,11 +69,11 @@ internal ChatRetrievedDocument(string content, IEnumerable searchQueries
/// The original search score for the retrieval.
/// If applicable, an indication of why the document was filtered.
/// Keeps track of any properties unknown to the library.
- internal ChatRetrievedDocument(string content, string title, Uri uri, string filePath, string chunkId, double? rerankScore, IReadOnlyList searchQueries, int dataSourceIndex, double? originalSearchScore, ChatDocumentFilterReason? filterReason, IDictionary serializedAdditionalRawData)
+ internal ChatRetrievedDocument(string content, string title, string url, string filePath, string chunkId, double? rerankScore, IReadOnlyList searchQueries, int dataSourceIndex, double? originalSearchScore, ChatDocumentFilterReason? filterReason, IDictionary serializedAdditionalRawData)
{
Content = content;
Title = title;
- Uri = uri;
+ Url = url;
FilePath = filePath;
ChunkId = chunkId;
RerankScore = rerankScore;
@@ -93,6 +93,8 @@ internal ChatRetrievedDocument()
public string Content { get; }
/// The title for the citation.
public string Title { get; }
+ /// The URL of the citation.
+ public string Url { get; }
/// The chunk ID for the citation.
public string ChunkId { get; }
/// The rerank score for the retrieval.
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.Serialization.cs
new file mode 100644
index 0000000000000..98a048d355390
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.Serialization.cs
@@ -0,0 +1,147 @@
+// <auto-generated/>
+
+#nullable disable
+
+using System;
+using System.ClientModel;
+using System.ClientModel.Primitives;
+using System.Collections.Generic;
+using System.Text.Json;
+
+namespace Azure.AI.OpenAI
+{
+ public partial class ContentFilterTextSpan : IJsonModel<ContentFilterTextSpan>
+ {
+ void IJsonModel<ContentFilterTextSpan>.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpan>)this).GetFormatFromOptions(options) : options.Format;
+ if (format != "J")
+ {
+ throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support writing '{format}' format.");
+ }
+
+ writer.WriteStartObject();
+ if (SerializedAdditionalRawData?.ContainsKey("completion_start_offset") != true)
+ {
+ writer.WritePropertyName("completion_start_offset"u8);
+ writer.WriteNumberValue(CompletionStartOffset);
+ }
+ if (SerializedAdditionalRawData?.ContainsKey("completion_end_offset") != true)
+ {
+ writer.WritePropertyName("completion_end_offset"u8);
+ writer.WriteNumberValue(CompletionEndOffset);
+ }
+ if (SerializedAdditionalRawData != null)
+ {
+ foreach (var item in SerializedAdditionalRawData)
+ {
+ if (ModelSerializationExtensions.IsSentinelValue(item.Value))
+ {
+ continue;
+ }
+ writer.WritePropertyName(item.Key);
+#if NET6_0_OR_GREATER
+ writer.WriteRawValue(item.Value);
+#else
+ using (JsonDocument document = JsonDocument.Parse(item.Value))
+ {
+ JsonSerializer.Serialize(writer, document.RootElement);
+ }
+#endif
+ }
+ }
+ writer.WriteEndObject();
+ }
+
+ ContentFilterTextSpan IJsonModel<ContentFilterTextSpan>.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpan>)this).GetFormatFromOptions(options) : options.Format;
+ if (format != "J")
+ {
+ throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support reading '{format}' format.");
+ }
+
+ using JsonDocument document = JsonDocument.ParseValue(ref reader);
+ return DeserializeContentFilterTextSpan(document.RootElement, options);
+ }
+
+ internal static ContentFilterTextSpan DeserializeContentFilterTextSpan(JsonElement element, ModelReaderWriterOptions options = null)
+ {
+ options ??= ModelSerializationExtensions.WireOptions;
+
+ if (element.ValueKind == JsonValueKind.Null)
+ {
+ return null;
+ }
+ int completionStartOffset = default;
+ int completionEndOffset = default;
+ IDictionary<string, BinaryData> serializedAdditionalRawData = default;
+ Dictionary<string, BinaryData> rawDataDictionary = new Dictionary<string, BinaryData>();
+ foreach (var property in element.EnumerateObject())
+ {
+ if (property.NameEquals("completion_start_offset"u8))
+ {
+ completionStartOffset = property.Value.GetInt32();
+ continue;
+ }
+ if (property.NameEquals("completion_end_offset"u8))
+ {
+ completionEndOffset = property.Value.GetInt32();
+ continue;
+ }
+ if (options.Format != "W")
+ {
+ rawDataDictionary ??= new Dictionary<string, BinaryData>();
+ rawDataDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText()));
+ }
+ }
+ serializedAdditionalRawData = rawDataDictionary;
+ return new ContentFilterTextSpan(completionStartOffset, completionEndOffset, serializedAdditionalRawData);
+ }
+
+ BinaryData IPersistableModel<ContentFilterTextSpan>.Write(ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpan>)this).GetFormatFromOptions(options) : options.Format;
+
+ switch (format)
+ {
+ case "J":
+ return ModelReaderWriter.Write(this, options);
+ default:
+ throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support writing '{options.Format}' format.");
+ }
+ }
+
+ ContentFilterTextSpan IPersistableModel<ContentFilterTextSpan>.Create(BinaryData data, ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpan>)this).GetFormatFromOptions(options) : options.Format;
+
+ switch (format)
+ {
+ case "J":
+ {
+ using JsonDocument document = JsonDocument.Parse(data);
+ return DeserializeContentFilterTextSpan(document.RootElement, options);
+ }
+ default:
+ throw new FormatException($"The model {nameof(ContentFilterTextSpan)} does not support reading '{options.Format}' format.");
+ }
+ }
+
+ string IPersistableModel<ContentFilterTextSpan>.GetFormatFromOptions(ModelReaderWriterOptions options) => "J";
+
+ /// <summary> Deserializes the model from a raw response. </summary>
+ /// <param name="response"> The result to deserialize the model from. </param>
+ internal static ContentFilterTextSpan FromResponse(PipelineResponse response)
+ {
+ using var document = JsonDocument.Parse(response.Content);
+ return DeserializeContentFilterTextSpan(document.RootElement);
+ }
+
+ /// <summary> Convert into a <see cref="BinaryContent"/>. </summary>
+ internal virtual BinaryContent ToBinaryContent()
+ {
+ return BinaryContent.Create(this, ModelSerializationExtensions.WireOptions);
+ }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.cs
new file mode 100644
index 0000000000000..2b04de8d01214
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpan.cs
@@ -0,0 +1,74 @@
+//
+
+#nullable disable
+
+using System;
+using System.Collections.Generic;
+
+namespace Azure.AI.OpenAI
+{
+ /// A representation of a span of completion text as used by Azure OpenAI content filter results.
+ public partial class ContentFilterTextSpan
+ {
+ ///
+ /// Keeps track of any properties unknown to the library.
+ ///
+ /// To assign an object to the value of this property use .
+ ///
+ ///
+ /// To assign an already formatted json string to this property use .
+ ///
+ ///
+ /// Examples:
+ ///
+ /// -
+ /// BinaryData.FromObjectAsJson("foo")
+ /// Creates a payload of "foo".
+ ///
+ /// -
+ /// BinaryData.FromString("\"foo\"")
+ /// Creates a payload of "foo".
+ ///
+ /// -
+ /// BinaryData.FromObjectAsJson(new { key = "value" })
+ /// Creates a payload of { "key": "value" }.
+ ///
+ /// -
+ /// BinaryData.FromString("{\"key\": \"value\"}")
+ /// Creates a payload of { "key": "value" }.
+ ///
+ ///
+ ///
+ ///
+ internal IDictionary<string, BinaryData> SerializedAdditionalRawData { get; set; }
+ /// <summary> Initializes a new instance of <see cref="ContentFilterTextSpan"/>. </summary>
+ /// <param name="completionStartOffset"> Offset of the UTF32 code point which begins the span. </param>
+ /// <param name="completionEndOffset"> Offset of the first UTF32 code point which is excluded from the span. This field is always equal to completion_start_offset for empty spans. This field is always larger than completion_start_offset for non-empty spans. </param>
+ internal ContentFilterTextSpan(int completionStartOffset, int completionEndOffset)
+ {
+ CompletionStartOffset = completionStartOffset;
+ CompletionEndOffset = completionEndOffset;
+ }
+
+ /// <summary> Initializes a new instance of <see cref="ContentFilterTextSpan"/>. </summary>
+ /// <param name="completionStartOffset"> Offset of the UTF32 code point which begins the span. </param>
+ /// <param name="completionEndOffset"> Offset of the first UTF32 code point which is excluded from the span. This field is always equal to completion_start_offset for empty spans. This field is always larger than completion_start_offset for non-empty spans. </param>
+ /// <param name="serializedAdditionalRawData"> Keeps track of any properties unknown to the library. </param>
+ internal ContentFilterTextSpan(int completionStartOffset, int completionEndOffset, IDictionary<string, BinaryData> serializedAdditionalRawData)
+ {
+ CompletionStartOffset = completionStartOffset;
+ CompletionEndOffset = completionEndOffset;
+ SerializedAdditionalRawData = serializedAdditionalRawData;
+ }
+
+ /// <summary> Initializes a new instance of <see cref="ContentFilterTextSpan"/> for deserialization. </summary>
+ internal ContentFilterTextSpan()
+ {
+ }
+
+ /// <summary> Offset of the UTF32 code point which begins the span. </summary>
+ public int CompletionStartOffset { get; }
+ /// <summary> Offset of the first UTF32 code point which is excluded from the span. This field is always equal to completion_start_offset for empty spans. This field is always larger than completion_start_offset for non-empty spans. </summary>
+ public int CompletionEndOffset { get; }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.Serialization.cs
new file mode 100644
index 0000000000000..4e7e99ab7caf4
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.Serialization.cs
@@ -0,0 +1,168 @@
+//
+
+#nullable disable
+
+using System;
+using System.ClientModel;
+using System.ClientModel.Primitives;
+using System.Collections.Generic;
+using System.Text.Json;
+
+namespace Azure.AI.OpenAI
+{
+ public partial class ContentFilterTextSpanResult : IJsonModel<ContentFilterTextSpanResult>
+ {
+ void IJsonModel<ContentFilterTextSpanResult>.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpanResult>)this).GetFormatFromOptions(options) : options.Format;
+ if (format != "J")
+ {
+ throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support writing '{format}' format.");
+ }
+
+ writer.WriteStartObject();
+ if (SerializedAdditionalRawData?.ContainsKey("filtered") != true)
+ {
+ writer.WritePropertyName("filtered"u8);
+ writer.WriteBooleanValue(Filtered);
+ }
+ if (SerializedAdditionalRawData?.ContainsKey("detected") != true)
+ {
+ writer.WritePropertyName("detected"u8);
+ writer.WriteBooleanValue(Detected);
+ }
+ if (SerializedAdditionalRawData?.ContainsKey("details") != true)
+ {
+ writer.WritePropertyName("details"u8);
+ writer.WriteStartArray();
+ foreach (var item in Details)
+ {
+ writer.WriteObjectValue(item, options);
+ }
+ writer.WriteEndArray();
+ }
+ if (SerializedAdditionalRawData != null)
+ {
+ foreach (var item in SerializedAdditionalRawData)
+ {
+ if (ModelSerializationExtensions.IsSentinelValue(item.Value))
+ {
+ continue;
+ }
+ writer.WritePropertyName(item.Key);
+#if NET6_0_OR_GREATER
+ writer.WriteRawValue(item.Value);
+#else
+ using (JsonDocument document = JsonDocument.Parse(item.Value))
+ {
+ JsonSerializer.Serialize(writer, document.RootElement);
+ }
+#endif
+ }
+ }
+ writer.WriteEndObject();
+ }
+
+ ContentFilterTextSpanResult IJsonModel<ContentFilterTextSpanResult>.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpanResult>)this).GetFormatFromOptions(options) : options.Format;
+ if (format != "J")
+ {
+ throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support reading '{format}' format.");
+ }
+
+ using JsonDocument document = JsonDocument.ParseValue(ref reader);
+ return DeserializeContentFilterTextSpanResult(document.RootElement, options);
+ }
+
+ internal static ContentFilterTextSpanResult DeserializeContentFilterTextSpanResult(JsonElement element, ModelReaderWriterOptions options = null)
+ {
+ options ??= ModelSerializationExtensions.WireOptions;
+
+ if (element.ValueKind == JsonValueKind.Null)
+ {
+ return null;
+ }
+ bool filtered = default;
+ bool detected = default;
+ IReadOnlyList<ContentFilterTextSpan> details = default;
+ IDictionary<string, BinaryData> serializedAdditionalRawData = default;
+ Dictionary<string, BinaryData> rawDataDictionary = new Dictionary<string, BinaryData>();
+ foreach (var property in element.EnumerateObject())
+ {
+ if (property.NameEquals("filtered"u8))
+ {
+ filtered = property.Value.GetBoolean();
+ continue;
+ }
+ if (property.NameEquals("detected"u8))
+ {
+ detected = property.Value.GetBoolean();
+ continue;
+ }
+ if (property.NameEquals("details"u8))
+ {
+ List<ContentFilterTextSpan> array = new List<ContentFilterTextSpan>();
+ foreach (var item in property.Value.EnumerateArray())
+ {
+ array.Add(ContentFilterTextSpan.DeserializeContentFilterTextSpan(item, options));
+ }
+ details = array;
+ continue;
+ }
+ if (options.Format != "W")
+ {
+ rawDataDictionary ??= new Dictionary<string, BinaryData>();
+ rawDataDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText()));
+ }
+ }
+ serializedAdditionalRawData = rawDataDictionary;
+ return new ContentFilterTextSpanResult(filtered, detected, details, serializedAdditionalRawData);
+ }
+
+ BinaryData IPersistableModel<ContentFilterTextSpanResult>.Write(ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpanResult>)this).GetFormatFromOptions(options) : options.Format;
+
+ switch (format)
+ {
+ case "J":
+ return ModelReaderWriter.Write(this, options);
+ default:
+ throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support writing '{options.Format}' format.");
+ }
+ }
+
+ ContentFilterTextSpanResult IPersistableModel<ContentFilterTextSpanResult>.Create(BinaryData data, ModelReaderWriterOptions options)
+ {
+ var format = options.Format == "W" ? ((IPersistableModel<ContentFilterTextSpanResult>)this).GetFormatFromOptions(options) : options.Format;
+
+ switch (format)
+ {
+ case "J":
+ {
+ using JsonDocument document = JsonDocument.Parse(data);
+ return DeserializeContentFilterTextSpanResult(document.RootElement, options);
+ }
+ default:
+ throw new FormatException($"The model {nameof(ContentFilterTextSpanResult)} does not support reading '{options.Format}' format.");
+ }
+ }
+
+ string IPersistableModel<ContentFilterTextSpanResult>.GetFormatFromOptions(ModelReaderWriterOptions options) => "J";
+
+ /// <summary> Deserializes the model from a raw response. </summary>
+ /// <param name="response"> The result to deserialize the model from. </param>
+ internal static ContentFilterTextSpanResult FromResponse(PipelineResponse response)
+ {
+ using var document = JsonDocument.Parse(response.Content);
+ return DeserializeContentFilterTextSpanResult(document.RootElement);
+ }
+
+ /// <summary> Convert into a <see cref="BinaryContent"/>. </summary>
+ internal virtual BinaryContent ToBinaryContent()
+ {
+ return BinaryContent.Create(this, ModelSerializationExtensions.WireOptions);
+ }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.cs
new file mode 100644
index 0000000000000..9f7a3dba1e0fa
--- /dev/null
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ContentFilterTextSpanResult.cs
@@ -0,0 +1,84 @@
+//
+
+#nullable disable
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Azure.AI.OpenAI
+{
+ /// The AzureContentFilterCompletionTextSpanDetectionResult.
+ public partial class ContentFilterTextSpanResult
+ {
+ ///
+ /// Keeps track of any properties unknown to the library.
+ ///
+ /// To assign an object to the value of this property use .
+ ///
+ ///
+ /// To assign an already formatted json string to this property use .
+ ///
+ ///
+ /// Examples:
+ ///
+ /// -
+ /// BinaryData.FromObjectAsJson("foo")
+ /// Creates a payload of "foo".
+ ///
+ /// -
+ /// BinaryData.FromString("\"foo\"")
+ /// Creates a payload of "foo".
+ ///
+ /// -
+ /// BinaryData.FromObjectAsJson(new { key = "value" })
+ /// Creates a payload of { "key": "value" }.
+ ///
+ /// -
+ /// BinaryData.FromString("{\"key\": \"value\"}")
+ /// Creates a payload of { "key": "value" }.
+ ///
+ ///
+ ///
+ ///
+ internal IDictionary<string, BinaryData> SerializedAdditionalRawData { get; set; }
+ /// <summary> Initializes a new instance of <see cref="ContentFilterTextSpanResult"/>. </summary>
+ /// <param name="filtered"> Whether the content detection resulted in a content filtering action. </param>
+ /// <param name="detected"> Whether the labeled content category was detected in the content. </param>
+ /// <param name="details"> Detailed information about the detected completion text spans. </param>
+ /// <exception cref="ArgumentNullException"> <paramref name="details"/> is null. </exception>
+ internal ContentFilterTextSpanResult(bool filtered, bool detected, IEnumerable<ContentFilterTextSpan> details)
+ {
+ Argument.AssertNotNull(details, nameof(details));
+
+ Filtered = filtered;
+ Detected = detected;
+ Details = details.ToList();
+ }
+
+ /// <summary> Initializes a new instance of <see cref="ContentFilterTextSpanResult"/>. </summary>
+ /// <param name="filtered"> Whether the content detection resulted in a content filtering action. </param>
+ /// <param name="detected"> Whether the labeled content category was detected in the content. </param>
+ /// <param name="details"> Detailed information about the detected completion text spans. </param>
+ /// <param name="serializedAdditionalRawData"> Keeps track of any properties unknown to the library. </param>
+ internal ContentFilterTextSpanResult(bool filtered, bool detected, IReadOnlyList<ContentFilterTextSpan> details, IDictionary<string, BinaryData> serializedAdditionalRawData)
+ {
+ Filtered = filtered;
+ Detected = detected;
+ Details = details;
+ SerializedAdditionalRawData = serializedAdditionalRawData;
+ }
+
+ /// Initializes a new instance of for deserialization.
+ internal ContentFilterTextSpanResult()
+ {
+ }
+
+ /// <summary> Whether the content detection resulted in a content filtering action. </summary>
+ public bool Filtered { get; }
+ /// <summary> Whether the labeled content category was detected in the content. </summary>
+ public bool Detected { get; }
+ /// <summary> Detailed information about the detected completion text spans. </summary>
+ public IReadOnlyList<ContentFilterTextSpan> Details { get; }
+ }
+}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs
index 5e9da43d62ebc..7d03d6540b2d1 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.Serialization.cs
@@ -66,6 +66,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelR
writer.WritePropertyName("protected_material_code"u8);
writer.WriteObjectValue(ProtectedMaterialCode, options);
}
+ if (SerializedAdditionalRawData?.ContainsKey("ungrounded_material") != true && Optional.IsDefined(UngroundedMaterial))
+ {
+ writer.WritePropertyName("ungrounded_material"u8);
+ writer.WriteObjectValue(UngroundedMaterial, options);
+ }
if (SerializedAdditionalRawData != null)
{
foreach (var item in SerializedAdditionalRawData)
@@ -117,6 +122,7 @@ internal static ResponseContentFilterResult DeserializeResponseContentFilterResu
InternalAzureContentFilterResultForPromptContentFilterResultsError error = default;
ContentFilterDetectionResult protectedMaterialText = default;
ContentFilterProtectedMaterialResult protectedMaterialCode = default;
+ ContentFilterTextSpanResult ungroundedMaterial = default;
IDictionary<string, BinaryData> serializedAdditionalRawData = default;
Dictionary<string, BinaryData> rawDataDictionary = new Dictionary<string, BinaryData>();
foreach (var property in element.EnumerateObject())
@@ -202,6 +208,15 @@ internal static ResponseContentFilterResult DeserializeResponseContentFilterResu
protectedMaterialCode = ContentFilterProtectedMaterialResult.DeserializeContentFilterProtectedMaterialResult(property.Value, options);
continue;
}
+ if (property.NameEquals("ungrounded_material"u8))
+ {
+ if (property.Value.ValueKind == JsonValueKind.Null)
+ {
+ continue;
+ }
+ ungroundedMaterial = ContentFilterTextSpanResult.DeserializeContentFilterTextSpanResult(property.Value, options);
+ continue;
+ }
if (options.Format != "W")
{
rawDataDictionary ??= new Dictionary<string, BinaryData>();
@@ -219,6 +234,7 @@ internal static ResponseContentFilterResult DeserializeResponseContentFilterResu
error,
protectedMaterialText,
protectedMaterialCode,
+ ungroundedMaterial,
serializedAdditionalRawData);
}
diff --git a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs
index be77b4567b3d0..87fda89bae54b 100644
--- a/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs
+++ b/sdk/openai/Azure.AI.OpenAI/src/Generated/ResponseContentFilterResult.cs
@@ -75,8 +75,9 @@ internal ResponseContentFilterResult()
/// If present, details about an error that prevented content filtering from completing its evaluation.
/// A detection result that describes a match against text protected under copyright or other status.
/// A detection result that describes a match against licensed code or other protected source material.
+ ///
/// Keeps track of any properties unknown to the library.
- internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, ContentFilterSeverityResult hate, ContentFilterSeverityResult violence, ContentFilterSeverityResult selfHarm, ContentFilterDetectionResult profanity, ContentFilterBlocklistResult customBlocklists, InternalAzureContentFilterResultForPromptContentFilterResultsError error, ContentFilterDetectionResult protectedMaterialText, ContentFilterProtectedMaterialResult protectedMaterialCode, IDictionary<string, BinaryData> serializedAdditionalRawData)
+ internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, ContentFilterSeverityResult hate, ContentFilterSeverityResult violence, ContentFilterSeverityResult selfHarm, ContentFilterDetectionResult profanity, ContentFilterBlocklistResult customBlocklists, InternalAzureContentFilterResultForPromptContentFilterResultsError error, ContentFilterDetectionResult protectedMaterialText, ContentFilterProtectedMaterialResult protectedMaterialCode, ContentFilterTextSpanResult ungroundedMaterial, IDictionary<string, BinaryData> serializedAdditionalRawData)
{
Sexual = sexual;
Hate = hate;
@@ -87,6 +88,7 @@ internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, Content
Error = error;
ProtectedMaterialText = protectedMaterialText;
ProtectedMaterialCode = protectedMaterialCode;
+ UngroundedMaterial = ungroundedMaterial;
SerializedAdditionalRawData = serializedAdditionalRawData;
}
@@ -125,5 +127,7 @@ internal ResponseContentFilterResult(ContentFilterSeverityResult sexual, Content
public ContentFilterDetectionResult ProtectedMaterialText { get; }
/// A detection result that describes a match against licensed code or other protected source material.
public ContentFilterProtectedMaterialResult ProtectedMaterialCode { get; }
+ /// Gets the ungrounded material.
+ public ContentFilterTextSpanResult UngroundedMaterial { get; }
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs
index 86bde5605ab71..d5957f918e3ba 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.Vision.cs
@@ -87,6 +87,7 @@ public async Task ChatWithImagesStreaming(bool useUri)
{
bool foundPromptFilter = false;
bool foundResponseFilter = false;
+ ChatTokenUsage? usage = null;
StringBuilder content = new();
ChatClient client = GetTestClient("vision");
@@ -123,9 +124,11 @@ public async Task ChatWithImagesStreaming(bool useUri)
await foreach (StreamingChatCompletionUpdate update in response)
{
- ValidateUpdate(update, content, ref foundPromptFilter, ref foundResponseFilter);
+ ValidateUpdate(update, content, ref foundPromptFilter, ref foundResponseFilter, ref usage);
}
+ // Assert.That(usage, Is.Not.Null);
+
// TODO FIXME: gpt-4o models seem to return inconsistent prompt filters to skip this for now
//Assert.That(foundPromptFilter, Is.True);
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs
index 096cfc602eede..83105dd0dddb9 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ChatTests.cs
@@ -142,6 +142,69 @@ public void DataSourceSerializationWorks()
Assert.That(sourcesFromOptions[1], Is.InstanceOf());
}
+#if !AZURE_OPENAI_GA
+ [Test]
+ [Category("Smoke")]
+ public async Task MaxTokensSerializationConfigurationWorks()
+ {
+ using MockHttpMessageHandler pipeline = new(MockHttpMessageHandler.ReturnEmptyJson);
+
+ Uri endpoint = new Uri("https://www.bing.com/");
+ string apiKey = "not-a-real-one";
+ string model = "ignore";
+
+ AzureOpenAIClient topLevel = new(
+ endpoint,
+ new ApiKeyCredential(apiKey),
+ new AzureOpenAIClientOptions()
+ {
+ Transport = pipeline.Transport
+ });
+
+ ChatClient client = topLevel.GetChatClient(model);
+
+ ChatCompletionOptions options = new();
+ bool GetSerializedOptionsContains(string value)
+ {
+ BinaryData serialized = ModelReaderWriter.Write(options);
+ return serialized.ToString().Contains(value);
+ }
+ async Task AssertExpectedSerializationAsync(bool hasOldMaxTokens, bool hasNewMaxCompletionTokens)
+ {
+ _ = await client.CompleteChatAsync(["Just mocking, no call here"], options);
+ Assert.That(GetSerializedOptionsContains("max_tokens"), Is.EqualTo(hasOldMaxTokens));
+ Assert.That(GetSerializedOptionsContains("max_completion_tokens"), Is.EqualTo(hasNewMaxCompletionTokens));
+ }
+
+ await AssertExpectedSerializationAsync(false, false);
+ await AssertExpectedSerializationAsync(false, false);
+
+ options.MaxOutputTokenCount = 42;
+ await AssertExpectedSerializationAsync(true, false);
+ await AssertExpectedSerializationAsync(true, false);
+ options.MaxOutputTokenCount = null;
+ await AssertExpectedSerializationAsync(false, false);
+ options.MaxOutputTokenCount = 42;
+ await AssertExpectedSerializationAsync(true, false);
+
+ options.SetNewMaxCompletionTokensPropertyEnabled();
+ await AssertExpectedSerializationAsync(false, true);
+ await AssertExpectedSerializationAsync(false, true);
+ options.MaxOutputTokenCount = null;
+ await AssertExpectedSerializationAsync(false, false);
+ options.MaxOutputTokenCount = 42;
+ await AssertExpectedSerializationAsync(false, true);
+
+ options.SetNewMaxCompletionTokensPropertyEnabled(false);
+ await AssertExpectedSerializationAsync(true, false);
+ await AssertExpectedSerializationAsync(true, false);
+ options.MaxOutputTokenCount = null;
+ await AssertExpectedSerializationAsync(false, false);
+ options.MaxOutputTokenCount = 42;
+ await AssertExpectedSerializationAsync(true, false);
+ }
+#endif
+
[RecordedTest]
public async Task ChatCompletionBadKeyGivesHelpfulError()
{
@@ -162,7 +225,6 @@ public async Task ChatCompletionBadKeyGivesHelpfulError()
}
[RecordedTest]
- [Category("Smoke")]
public async Task DefaultAzureCredentialWorks()
{
ChatClient chatClient = GetTestClient(tokenCredential: this.TestEnvironment.Credential);
@@ -492,6 +554,7 @@ public async Task ChatCompletionStreaming()
StringBuilder builder = new();
bool foundPromptFilter = false;
bool foundResponseFilter = false;
+ ChatTokenUsage? usage = null;
ChatClient chatClient = GetTestClient();
@@ -512,12 +575,14 @@ public async Task ChatCompletionStreaming()
await foreach (StreamingChatCompletionUpdate update in streamingResults)
{
- ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter);
+ ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter, ref usage);
}
string allText = builder.ToString();
Assert.That(allText, Is.Not.Null.Or.Empty);
+ Assert.That(usage, Is.Not.Null);
+
Assert.That(foundPromptFilter, Is.True);
Assert.That(foundResponseFilter, Is.True);
}
@@ -528,6 +593,7 @@ public async Task SearchExtensionWorksStreaming()
StringBuilder builder = new();
bool foundPromptFilter = false;
bool foundResponseFilter = false;
+ ChatTokenUsage? usage = null;
List contexts = new();
var searchConfig = TestConfig.GetConfig("search")!;
@@ -555,7 +621,7 @@ public async Task SearchExtensionWorksStreaming()
await foreach (StreamingChatCompletionUpdate update in chatUpdates)
{
- ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter);
+ ValidateUpdate(update, builder, ref foundPromptFilter, ref foundResponseFilter, ref usage);
ChatMessageContext context = update.GetMessageContext();
if (context != null)
@@ -567,6 +633,8 @@ public async Task SearchExtensionWorksStreaming()
string allText = builder.ToString();
Assert.That(allText, Is.Not.Null.Or.Empty);
+ // Assert.That(usage, Is.Not.Null);
+
// TODO FIXME: When using data sources, the service does not appear to return request nor response filtering information
//Assert.That(foundPromptFilter, Is.True);
//Assert.That(foundResponseFilter, Is.True);
@@ -636,7 +704,7 @@ in client.CompleteChatStreamingAsync(
#endregion
#region Helper methods
- private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder builder, ref bool foundPromptFilter, ref bool foundResponseFilter)
+ private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder builder, ref bool foundPromptFilter, ref bool foundResponseFilter, ref ChatTokenUsage? usage)
{
if (update.CreatedAt == UNIX_EPOCH)
{
@@ -656,6 +724,8 @@ private void ValidateUpdate(StreamingChatCompletionUpdate update, StringBuilder
Assert.That(update.FinishReason, Is.Null.Or.EqualTo(ChatFinishReason.Stop));
if (update.Usage != null)
{
+ Assert.That(usage, Is.Null);
+ usage = update.Usage;
Assert.That(update.Usage.InputTokenCount, Is.GreaterThanOrEqualTo(0));
Assert.That(update.Usage.OutputTokenCount, Is.GreaterThanOrEqualTo(0));
Assert.That(update.Usage.TotalTokenCount, Is.GreaterThanOrEqualTo(0));
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs
index db54c9583c1c0..a3a3194357134 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ConversationSmokeTests.cs
@@ -21,7 +21,7 @@ public void ItemCreation()
{
ConversationItem messageItem = ConversationItem.CreateUserMessage(["Hello, world!"]);
Assert.That(messageItem?.MessageContentParts?.Count, Is.EqualTo(1));
- Assert.That(messageItem.MessageContentParts[0].TextValue, Is.EqualTo("Hello, world!"));
+ Assert.That(messageItem.MessageContentParts[0].Text, Is.EqualTo("Hello, world!"));
}
[Test]
@@ -36,8 +36,7 @@ public void OptionsSerializationWorks()
Model = "whisper-1",
},
Instructions = "test instructions",
- MaxResponseOutputTokens = 42,
- Model = "gpt-4o-realtime-preview",
+ MaxOutputTokens = 42,
OutputAudioFormat = ConversationAudioFormat.G711Ulaw,
Temperature = 0.42f,
ToolChoice = ConversationToolChoice.CreateFunctionToolChoice("test-function"),
@@ -67,7 +66,6 @@ public void OptionsSerializationWorks()
Assert.That(jsonNode["input_audio_transcription"]?["model"]?.GetValue<string>(), Is.EqualTo("whisper-1"));
Assert.That(jsonNode["instructions"]?.GetValue<string>(), Is.EqualTo("test instructions"));
Assert.That(jsonNode["max_response_output_tokens"]?.GetValue<int>(), Is.EqualTo(42));
- Assert.That(jsonNode["model"]?.GetValue<string>(), Is.EqualTo("gpt-4o-realtime-preview"));
Assert.That(jsonNode["output_audio_format"]?.GetValue<string>(), Is.EqualTo("g711_ulaw"));
Assert.That(jsonNode["temperature"]?.GetValue<float>(), Is.EqualTo(0.42f));
Assert.That(jsonNode["tools"]?.AsArray()?.ToList(), Has.Count.EqualTo(1));
@@ -85,8 +83,7 @@ public void OptionsSerializationWorks()
Assert.That(deserializedOptions.InputAudioFormat, Is.EqualTo(ConversationAudioFormat.G711Alaw));
Assert.That(deserializedOptions.InputTranscriptionOptions?.Model, Is.EqualTo(ConversationTranscriptionModel.Whisper1));
Assert.That(deserializedOptions.Instructions, Is.EqualTo("test instructions"));
- Assert.That(deserializedOptions.MaxResponseOutputTokens.NumericValue, Is.EqualTo(42));
- Assert.That(deserializedOptions.Model, Is.EqualTo("gpt-4o-realtime-preview"));
+ Assert.That(deserializedOptions.MaxOutputTokens.NumericValue, Is.EqualTo(42));
Assert.That(deserializedOptions.OutputAudioFormat, Is.EqualTo(ConversationAudioFormat.G711Ulaw));
Assert.That(deserializedOptions.Tools, Has.Count.EqualTo(1));
Assert.That(deserializedOptions.Tools[0].Kind, Is.EqualTo(ConversationToolKind.Function));
@@ -97,5 +94,92 @@ public void OptionsSerializationWorks()
Assert.That(deserializedOptions.ToolChoice?.FunctionName, Is.EqualTo("test-function"));
Assert.That(deserializedOptions.TurnDetectionOptions?.Kind, Is.EqualTo(ConversationTurnDetectionKind.ServerVoiceActivityDetection));
Assert.That(deserializedOptions.Voice, Is.EqualTo(ConversationVoice.Echo));
+
+ ConversationSessionOptions emptyOptions = new();
+ Assert.That(emptyOptions.ContentModalities.HasFlag(ConversationContentModalities.Audio), Is.False);
+ Assert.That(ModelReaderWriter.Write(emptyOptions).ToString(), Does.Not.Contain("modal"));
+ emptyOptions.ContentModalities |= ConversationContentModalities.Audio;
+ Assert.That(emptyOptions.ContentModalities.HasFlag(ConversationContentModalities.Audio), Is.True);
+ Assert.That(emptyOptions.ContentModalities.HasFlag(ConversationContentModalities.Text), Is.False);
+ Assert.That(ModelReaderWriter.Write(emptyOptions).ToString(), Does.Contain("modal"));
+ }
+
+ [Test]
+ public void MaxTokensSerializationWorks()
+ {
+ // Implicit omission
+ ConversationSessionOptions options = new() { };
+ BinaryData serializedOptions = ModelReaderWriter.Write(options);
+ Assert.That(serializedOptions.ToString(), Does.Not.Contain("max_response_output_tokens"));
+
+ // Explicit omission
+ options = new()
+ {
+ MaxOutputTokens = null
+ };
+ serializedOptions = ModelReaderWriter.Write(options);
+ Assert.That(serializedOptions.ToString(), Does.Not.Contain("max_response_output_tokens"));
+
+ // Explicit default (null)
+ options = new()
+ {
+ MaxOutputTokens = ConversationMaxTokensChoice.CreateDefaultMaxTokensChoice()
+ };
+ serializedOptions = ModelReaderWriter.Write(options);
+ Assert.That(serializedOptions.ToString(), Does.Contain(@"""max_response_output_tokens"":null"));
+
+ // Numeric literal
+ options = new()
+ {
+ MaxOutputTokens = 42,
+ };
+ serializedOptions = ModelReaderWriter.Write(options);
+ Assert.That(serializedOptions.ToString(), Does.Contain(@"""max_response_output_tokens"":42"));
+
+ // Numeric by factory
+ options = new()
+ {
+ MaxOutputTokens = ConversationMaxTokensChoice.CreateNumericMaxTokensChoice(42)
+ };
+ serializedOptions = ModelReaderWriter.Write(options);
+ Assert.That(serializedOptions.ToString(), Does.Contain(@"""max_response_output_tokens"":42"));
+ }
+
+ [Test]
+ public void TurnDetectionSerializationWorks()
+ {
+ // Implicit omission
+ ConversationSessionOptions sessionOptions = new();
+ BinaryData serializedOptions = ModelReaderWriter.Write(sessionOptions);
+ Assert.That(serializedOptions.ToString(), Does.Not.Contain("turn_detection"));
+
+ sessionOptions = new()
+ {
+ TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(),
+ };
+ serializedOptions = ModelReaderWriter.Write(sessionOptions);
+ Assert.That(serializedOptions.ToString(), Does.Contain(@"""turn_detection"":null"));
+
+ sessionOptions = new()
+ {
+ TurnDetectionOptions = ConversationTurnDetectionOptions.CreateServerVoiceActivityTurnDetectionOptions(
+ detectionThreshold: 0.42f)
+ };
+ serializedOptions = ModelReaderWriter.Write(sessionOptions);
+ JsonNode serializedNode = JsonNode.Parse(serializedOptions);
+ Assert.That(serializedNode["turn_detection"]?["type"]?.GetValue<string>(), Is.EqualTo("server_vad"));
+ Assert.That(serializedNode["turn_detection"]?["threshold"]?.GetValue<float>(), Is.EqualTo(0.42f));
+ }
+
+ [Test]
+ public void UnknownCommandSerializationWorks()
+ {
+ BinaryData serializedUnknownCommand = BinaryData.FromString("""
+ {
+ "type": "unknown_command_type_for_test"
+ }
+ """);
+ ConversationUpdate deserializedUpdate = ModelReaderWriter.Read<ConversationUpdate>(serializedUnknownCommand);
+ Assert.That(deserializedUpdate, Is.Not.Null);
}
}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs
index 5d6624e2af9b4..d718cc1eb5811 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTestFixtureBase.cs
@@ -30,7 +30,7 @@ public ConversationTestFixtureBase(bool isAsync)
CancellationTokenSource = new();
if (!Debugger.IsAttached)
{
- CancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(15));
+ CancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(25));
}
DefaultConfiguration = TestConfig.GetConfig("rt_eus2");
if (DefaultConfiguration is null || DefaultConfiguration.Endpoint is null)
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs
index 2d3e47214466b..89cbeda0c6865 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/ConversationTests.cs
@@ -1,4 +1,5 @@
using NUnit.Framework;
+using OpenAI;
using OpenAI.RealtimeConversation;
using System;
using System.ClientModel.Primitives;
@@ -26,29 +27,51 @@ public async Task CanConfigureSession()
RealtimeConversationClient client = GetTestClient();
using RealtimeConversationSession session = await client.StartConversationSessionAsync(CancellationToken);
- await session.ConfigureSessionAsync(
- new ConversationSessionOptions()
- {
- Instructions = "You are a helpful assistant.",
- TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(),
- OutputAudioFormat = ConversationAudioFormat.G711Ulaw
- },
- CancellationToken);
+ ConversationSessionOptions sessionOptions = new()
+ {
+ Instructions = "You are a helpful assistant.",
+ TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(),
+ OutputAudioFormat = ConversationAudioFormat.G711Ulaw,
+ MaxOutputTokens = 2048,
+ };
- await session.StartResponseTurnAsync(CancellationToken);
+ await session.ConfigureSessionAsync(sessionOptions, CancellationToken);
+ ConversationSessionOptions responseOverrideOptions = new()
+ {
+ ContentModalities = ConversationContentModalities.Text,
+ };
+ if (!client.GetType().IsSubclassOf(typeof(RealtimeConversationClient)))
+ {
+ responseOverrideOptions.MaxOutputTokens = ConversationMaxTokensChoice.CreateInfiniteMaxTokensChoice();
+ }
+ await session.AddItemAsync(
+ ConversationItem.CreateUserMessage(["Hello, assistant! Tell me a joke."]),
+ CancellationToken);
+ await session.StartResponseAsync(responseOverrideOptions, CancellationToken);
List<ConversationUpdate> receivedUpdates = [];
await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken))
{
receivedUpdates.Add(update);
-
+
if (update is ConversationErrorUpdate errorUpdate)
{
Assert.That(errorUpdate.Kind, Is.EqualTo(ConversationUpdateKind.Error));
Assert.Fail($"Error: {ModelReaderWriter.Write(errorUpdate)}");
}
- else if (update is ConversationResponseFinishedUpdate)
+ else if ((update is ConversationItemStreamingPartDeltaUpdate deltaUpdate && deltaUpdate.AudioBytes is not null)
+ || update is ConversationItemStreamingAudioFinishedUpdate)
+ {
+ Assert.Fail($"Audio content streaming unexpected after configuring response-level text-only modalities");
+ }
+ else if (update is ConversationSessionConfiguredUpdate sessionConfiguredUpdate)
+ {
+ Assert.That(sessionConfiguredUpdate.OutputAudioFormat == sessionOptions.OutputAudioFormat);
+ Assert.That(sessionConfiguredUpdate.TurnDetectionOptions.Kind, Is.EqualTo(ConversationTurnDetectionKind.Disabled));
+ Assert.That(sessionConfiguredUpdate.MaxOutputTokens.NumericValue, Is.EqualTo(sessionOptions.MaxOutputTokens.NumericValue));
+ }
+ else if (update is ConversationResponseFinishedUpdate turnFinishedUpdate)
{
break;
}
@@ -62,8 +85,8 @@ List<T> GetReceivedUpdates<T>() where T : ConversationUpdate
Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
- Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
- Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
+ Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
+ Assert.That(GetReceivedUpdates(), Has.Count.EqualTo(1));
}
[Test]
@@ -74,9 +97,11 @@ public async Task TextOnlyWorks()
await session.AddItemAsync(
ConversationItem.CreateUserMessage(["Hello, world!"]),
cancellationToken: CancellationToken);
- await session.StartResponseTurnAsync(CancellationToken);
+ await session.StartResponseAsync(CancellationToken);
StringBuilder responseBuilder = new();
+ bool gotResponseDone = false;
+ bool gotRateLimits = false;
await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken))
{
@@ -84,23 +109,134 @@ await session.AddItemAsync(
{
Assert.That(sessionStartedUpdate.SessionId, Is.Not.Null.And.Not.Empty);
}
- if (update is ConversationTextDeltaUpdate textDeltaUpdate)
+ if (update is ConversationItemStreamingPartDeltaUpdate deltaUpdate)
{
- responseBuilder.Append(textDeltaUpdate.Delta);
+ responseBuilder.Append(deltaUpdate.AudioTranscript);
}
- if (update is ConversationItemAcknowledgedUpdate itemAddedUpdate)
+ if (update is ConversationItemCreatedUpdate itemCreatedUpdate)
{
- Assert.That(itemAddedUpdate.Item is not null);
+ if (itemCreatedUpdate.MessageRole == ConversationMessageRole.Assistant)
+ {
+ // The assistant-created item should be streamed and should not have content yet when acknowledged
+ Assert.That(itemCreatedUpdate.MessageContentParts, Has.Count.EqualTo(0));
+ }
+ else if (itemCreatedUpdate.MessageRole == ConversationMessageRole.User)
+ {
+ // When acknowledging an item added by the client (user), the text should already be there
+ Assert.That(itemCreatedUpdate.MessageContentParts, Has.Count.EqualTo(1));
+ Assert.That(itemCreatedUpdate.MessageContentParts[0].Text, Is.EqualTo("Hello, world!"));
+ }
+ else
+ {
+ Assert.Fail($"Test didn't expect an acknowledged item with role: {itemCreatedUpdate.MessageRole}");
+ }
}
- if (update is ConversationResponseFinishedUpdate)
+ if (update is ConversationResponseFinishedUpdate responseFinishedUpdate)
{
+ Assert.That(responseFinishedUpdate.CreatedItems, Has.Count.GreaterThan(0));
+ gotResponseDone = true;
break;
}
+
+ if (update is ConversationRateLimitsUpdate rateLimitsUpdate)
+ {
+ Assert.That(rateLimitsUpdate.AllDetails, Has.Count.EqualTo(2));
+ Assert.That(rateLimitsUpdate.TokenDetails, Is.Not.Null);
+ Assert.That(rateLimitsUpdate.TokenDetails.Name, Is.EqualTo("tokens"));
+ Assert.That(rateLimitsUpdate.TokenDetails.MaximumCount, Is.GreaterThan(0));
+ Assert.That(rateLimitsUpdate.TokenDetails.RemainingCount, Is.GreaterThan(0));
+ Assert.That(rateLimitsUpdate.TokenDetails.RemainingCount, Is.LessThan(rateLimitsUpdate.TokenDetails.MaximumCount));
+ Assert.That(rateLimitsUpdate.TokenDetails.TimeUntilReset, Is.GreaterThan(TimeSpan.Zero));
+ Assert.That(rateLimitsUpdate.RequestDetails, Is.Not.Null);
+ gotRateLimits = true;
+ }
}
Assert.That(responseBuilder.ToString(), Is.Not.Null.Or.Empty);
+ Assert.That(gotResponseDone, Is.True);
+
+ if (!client.GetType().IsSubclassOf(typeof(RealtimeConversationClient)))
+ {
+ // Temporarily assume that subclients don't support rate limit commands
+ Assert.That(gotRateLimits, Is.True);
+ }
+ }
+
+ [Test]
+ public async Task ItemManipulationWorks()
+ {
+ RealtimeConversationClient client = GetTestClient();
+ using RealtimeConversationSession session = await client.StartConversationSessionAsync(CancellationToken);
+
+ await session.ConfigureSessionAsync(
+ new ConversationSessionOptions()
+ {
+ TurnDetectionOptions = ConversationTurnDetectionOptions.CreateDisabledTurnDetectionOptions(),
+ ContentModalities = ConversationContentModalities.Text,
+ },
+ CancellationToken);
+
+ await session.AddItemAsync(
+ ConversationItem.CreateUserMessage(["The first special word you know about is 'aardvark'."]),
+ CancellationToken);
+ await session.AddItemAsync(
+ ConversationItem.CreateUserMessage(["The next special word you know about is 'banana'."]),
+ CancellationToken);
+ await session.AddItemAsync(
+ ConversationItem.CreateUserMessage(["The next special word you know about is 'coconut'."]),
+ CancellationToken);
+
+ bool gotSessionStarted = false;
+ bool gotSessionConfigured = false;
+ bool gotResponseFinished = false;
+
+ await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken))
+ {
+ if (update is ConversationSessionStartedUpdate)
+ {
+ gotSessionStarted = true;
+ }
+
+ if (update is ConversationSessionConfiguredUpdate sessionConfiguredUpdate)
+ {
+ Assert.That(sessionConfiguredUpdate.TurnDetectionOptions.Kind, Is.EqualTo(ConversationTurnDetectionKind.Disabled));
+ Assert.That(sessionConfiguredUpdate.ContentModalities.HasFlag(ConversationContentModalities.Text), Is.True);
+ Assert.That(sessionConfiguredUpdate.ContentModalities.HasFlag(ConversationContentModalities.Audio), Is.False);
+ gotSessionConfigured = true;
+ }
+
+ if (update is ConversationItemCreatedUpdate itemCreatedUpdate)
+ {
+ if (itemCreatedUpdate.MessageContentParts.Count > 0
+ && itemCreatedUpdate.MessageContentParts[0].Text.Contains("banana"))
+ {
+ await session.DeleteItemAsync(itemCreatedUpdate.ItemId, CancellationToken);
+ await session.AddItemAsync(
+ ConversationItem.CreateUserMessage(["What's the second special word you know about?"]),
+ CancellationToken);
+ await session.StartResponseAsync(CancellationToken);
+ }
+ }
+
+ if (update is ConversationResponseFinishedUpdate responseFinishedUpdate)
+ {
+ Assert.That(responseFinishedUpdate.CreatedItems.Count, Is.EqualTo(1));
+ Assert.That(responseFinishedUpdate.CreatedItems[0].MessageContentParts.Count, Is.EqualTo(1));
+ Assert.That(responseFinishedUpdate.CreatedItems[0].MessageContentParts[0].Text, Does.Contain("coconut"));
+ Assert.That(responseFinishedUpdate.CreatedItems[0].MessageContentParts[0].Text, Does.Not.Contain("banana"));
+ gotResponseFinished = true;
+ break;
+ }
+ }
+
+ Assert.That(gotSessionStarted, Is.True);
+ if (!client.GetType().IsSubclassOf(typeof(RealtimeConversationClient)))
+ {
+ Assert.That(gotSessionConfigured, Is.True);
+ }
+ Assert.That(gotResponseFinished, Is.True);
}
[Test]
@@ -150,21 +286,15 @@ public async Task AudioWithToolsWorks()
await session.ConfigureSessionAsync(options, CancellationToken);
- const string folderName = "Assets";
- const string fileName = "whats_the_weather_pcm16_24khz_mono.wav";
-#if NET6_0_OR_GREATER
- using Stream audioStream = File.OpenRead(Path.Join(folderName, fileName));
-#else
- using Stream audioStream = File.OpenRead($"{folderName}\\{fileName}");
-#endif
- _ = session.SendAudioAsync(audioStream, CancellationToken);
+ string audioFilePath = Directory.EnumerateFiles("Assets")
+ .First(path => path.Contains("whats_the_weather_pcm16_24khz_mono.wav"));
+ using Stream audioStream = File.OpenRead(audioFilePath);
+ _ = session.SendInputAudioAsync(audioStream, CancellationToken);
string userTranscript = null;
await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken))
{
- Assert.That(update.EventId, Is.Not.Null.And.Not.Empty);
-
if (update is ConversationSessionStartedUpdate sessionStartedUpdate)
{
Assert.That(sessionStartedUpdate.SessionId, Is.Not.Null.And.Not.Empty);
@@ -175,12 +305,12 @@ public async Task AudioWithToolsWorks()
Assert.That(sessionStartedUpdate.Temperature, Is.GreaterThan(0));
}
- if (update is ConversationInputTranscriptionFinishedUpdate inputTranscriptionFinishedUpdate)
+ if (update is ConversationInputTranscriptionFinishedUpdate inputTranscriptionCompletedUpdate)
{
- userTranscript = inputTranscriptionFinishedUpdate.Transcript;
+ userTranscript = inputTranscriptionCompletedUpdate.Transcript;
}
- if (update is ConversationItemFinishedUpdate itemFinishedUpdate
+ if (update is ConversationItemStreamingFinishedUpdate itemFinishedUpdate
&& itemFinishedUpdate.FunctionCallId is not null)
{
Assert.That(itemFinishedUpdate.FunctionName, Is.EqualTo(getWeatherTool.Name));
@@ -195,7 +325,7 @@ public async Task AudioWithToolsWorks()
{
if (turnFinishedUpdate.CreatedItems.Any(item => !string.IsNullOrEmpty(item.FunctionCallId)))
{
- await session.StartResponseTurnAsync(CancellationToken);
+ await session.StartResponseAsync(CancellationToken);
}
else
{
@@ -227,7 +357,7 @@ await session.ConfigureSessionAsync(
#else
using Stream audioStream = File.OpenRead($"{folderName}\\{fileName}");
#endif
- await session.SendAudioAsync(audioStream, CancellationToken);
+ await session.SendInputAudioAsync(audioStream, CancellationToken);
await session.AddItemAsync(ConversationItem.CreateUserMessage(["Hello, assistant!"]), CancellationToken);
@@ -248,11 +378,41 @@ or ConversationResponseStartedUpdate
Assert.Fail($"Shouldn't receive any VAD events or response creation!");
}
- if (update is ConversationItemAcknowledgedUpdate itemAcknowledgedUpdate
- && itemAcknowledgedUpdate.Item.MessageRole == ConversationMessageRole.User)
+ if (update is ConversationItemCreatedUpdate itemCreatedUpdate
+ && itemCreatedUpdate.MessageRole == ConversationMessageRole.User)
{
break;
}
}
}
+
+ [Test]
+ public async Task BadCommandProvidesError()
+ {
+ RealtimeConversationClient client = GetTestClient();
+ using RealtimeConversationSession session = await client.StartConversationSessionAsync(CancellationToken);
+
+ await session.SendCommandAsync(
+ BinaryData.FromString("""
+ {
+ "type": "update_conversation_config2",
+ "event_id": "event_fabricated_1234abcd"
+ }
+ """),
+ CancellationToken);
+
+ bool gotErrorUpdate = false;
+
+ await foreach (ConversationUpdate update in session.ReceiveUpdatesAsync(CancellationToken))
+ {
+ if (update is ConversationErrorUpdate errorUpdate)
+ {
+ Assert.That(errorUpdate.ErrorEventId, Is.EqualTo("event_fabricated_1234abcd"));
+ gotErrorUpdate = true;
+ break;
+ }
+ }
+
+ Assert.That(gotErrorUpdate, Is.True);
+ }
}
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs
index 73f5ba7d86058..ce12b8f82498e 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/FineTuningTests.cs
@@ -24,10 +24,16 @@
namespace Azure.AI.OpenAI.Tests;
+[Category("FineTuning")]
public class FineTuningTests : AoaiTestBase
{
public FineTuningTests(bool isAsync) : base(isAsync)
- { }
+ {
+ if (Mode == RecordedTestMode.Playback)
+ {
+ Assert.Inconclusive("Playback for fine-tuning temporarily disabled");
+ }
+ }
#if !AZURE_OPENAI_GA
[Test]
@@ -57,7 +63,6 @@ public async Task JobsFineTuning()
}
[RecordedTest]
- [Ignore("Disable pending resolution of test framework interception overriding Rehydrateoverrides")]
public async Task CheckpointsFineTuning()
{
string fineTunedModel = GetFineTunedModel();
@@ -69,7 +74,7 @@ public async Task CheckpointsFineTuning()
Assert.That(job, Is.Not.Null);
Assert.That(job!.Status, Is.EqualTo("succeeded"));
- FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(client, job.ID);
+ FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(UnWrap(client), job.ID);
int count = 25;
await foreach (FineTuningCheckpoint checkpoint in EnumerateCheckpoints(fineTuningJobOperation))
@@ -95,7 +100,6 @@ public async Task CheckpointsFineTuning()
}
[RecordedTest]
- [Ignore("Disable pending resolution of test framework interception overriding Rehydrateoverrides")]
public async Task EventsFineTuning()
{
string fineTunedModel = GetFineTunedModel();
@@ -109,7 +113,8 @@ public async Task EventsFineTuning()
HashSet<string> ids = new();
- FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(client, job.ID);
+ //TODO fix unwrapping so you don't have to unwrap here.
+ FineTuningJobOperation fineTuningJobOperation = await FineTuningJobOperation.RehydrateAsync(UnWrap(client), job.ID);
int count = 25;
var asyncEnum = EnumerateAsync((after, limit, opt) => fineTuningJobOperation.GetJobEventsAsync(after, limit, opt));
@@ -133,7 +138,6 @@ public async Task EventsFineTuning()
}
[RecordedTest]
- [Ignore("Disabling pending resolution of AOAI quota issues and re-recording")]
public async Task CreateAndCancelFineTuning()
{
var fineTuningFile = Assets.FineTuning;
@@ -141,8 +145,24 @@ public async Task CreateAndCancelFineTuning()
FineTuningClient client = GetTestClient();
OpenAIFileClient fileClient = GetTestClientFrom(client);
- // upload training data
- OpenAIFile uploadedFile = await UploadAndWaitForCompleteOrFail(fileClient, fineTuningFile.RelativePath);
+ OpenAIFile uploadedFile;
+ try
+ {
+ ClientResult fileResult = await fileClient.GetFileAsync("file-db5f5bfe5ea04ffcaeba89947a872828", new RequestOptions() { });
+ uploadedFile = ValidateAndParse<OpenAIFile>(fileResult);
+ }
+ catch (ClientResultException e)
+ {
+ if (e.Message.Contains("ResourceNotFound"))
+ {
+ // upload training data
+ uploadedFile = await UploadAndWaitForCompleteOrFail(fileClient, fineTuningFile.RelativePath);
+ }
+ else
+ {
+ throw;
+ }
+ }
// Create the fine tuning job
using var requestContent = new FineTuningOptions()
@@ -194,9 +214,7 @@ public async Task CreateAndCancelFineTuning()
Assert.True(operation.HasCompleted);
}
- [RecordedTest(AutomaticRecord = false)]
- [Category("LongRunning")] // CAUTION: This test can take up 30 *minutes* to run in live mode
- [Ignore("Disabled pending investigation of 404")]
+ [RecordedTest]
public async Task CreateAndDeleteFineTuning()
{
var fineTuningFile = Assets.FineTuning;
@@ -226,7 +244,12 @@ public async Task CreateAndDeleteFineTuning()
using var requestContent = new FineTuningOptions()
{
Model = client.DeploymentOrThrow(),
- TrainingFile = uploadedFile.Id
+ TrainingFile = uploadedFile.Id,
+ Hyperparameters = new FineTuningHyperparameters()
+ {
+ NumEpochs = 1,
+ BatchSize = 11
+ }
}.ToBinaryContent();
FineTuningJobOperation operation = await client.CreateFineTuningJobAsync(requestContent, waitUntilCompleted: false);
@@ -234,12 +257,12 @@ public async Task CreateAndDeleteFineTuning()
Assert.That(job.ID, Is.Not.Null.Or.Empty);
Assert.That(job.Error, Is.Null);
Assert.That(job.Status, !(Is.Null.Or.EqualTo("failed").Or.EqualTo("cancelled")));
+ await operation.CancelAsync(options: null);
// Wait for the fine tuning to complete
await operation.WaitForCompletionAsync();
job = ValidateAndParse(await operation.GetJobAsync(null));
- Assert.That(job.Status, Is.EqualTo("succeeded"), "Fine tuning did not succeed");
- Assert.That(job.FineTunedModel, Is.Not.Null.Or.Empty);
+ Assert.That(job.Status, Is.EqualTo("cancelled"), "Fine tuning did not cancel");
// Delete the fine tuned model
bool deleted = await DeleteJobAndVerifyAsync((AzureFineTuningJobOperation)operation, job.ID);
diff --git a/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs b/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs
index e074e60debb47..3443a19471fc3 100644
--- a/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs
+++ b/sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs
@@ -176,7 +176,7 @@ public async Task CanAssociateFiles()
Assert.True(removalResult.Removed);
// Errata: removals aren't immediately reflected when requesting the list
- Thread.Sleep(1000);
+ await Task.Delay(TimeSpan.FromSeconds(5));
int count = 0;
AsyncCollectionResult<VectorStoreFileAssociation> response = client.GetFileAssociationsAsync(vectorStore.Id);
diff --git a/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs b/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs
index d5d3edff45f01..26d798e2e49b9 100644
--- a/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs
+++ b/sdk/openai/tools/TestFramework/tests/MockStringServiceTests.cs
@@ -22,7 +22,11 @@ public MockStringServiceTests(bool isAsync)
RecordingOptions.SanitizersToRemove.Add("AZSDK3430"); // $..id
}
- public DirectoryInfo RepositoryRoot { get; } = FindRepoRoot();
+ public DirectoryInfo RepositoryRoot { get; } = FindFirstParentWithSubfolders(".git")
+ ?? throw new InvalidOperationException("Could not find your Git repository root folder");
+
+ public DirectoryInfo SourceRoot { get; } = FindFirstParentWithSubfolders("eng", "sdk")
+ ?? throw new InvalidOperationException("Could not find your source root folder");
[Test]
public async Task AddAndGet()
@@ -71,7 +75,7 @@ protected override ProxyServiceOptions CreateProxyServiceOptions()
DotnetExecutable = AssemblyHelper.GetDotnetExecutable()?.FullName!,
TestProxyDll = AssemblyHelper.GetAssemblyMetadata("TestProxyPath")!,
DevCertFile = Path.Combine(
- RepositoryRoot.FullName,
+ SourceRoot.FullName,
"eng",
"common",
"testproxy",
@@ -91,25 +95,27 @@ protected override RecordingStartInformation CreateRecordingSessionStartInfo()
#region helper methods
- private static DirectoryInfo FindRepoRoot()
+ private static DirectoryInfo? FindFirstParentWithSubfolders(params string[] subFolders)
{
- /**
- * This code assumes that we are running in the standard Azure .Net SDK repository layout. With this in mind,
- * we generally assume that we are running our test code from
- * /artifacts/bin///
- * So to find the root we keep navigating up until we find a folder with a .git subfolder
- *
- * Another alternative would be to call: git rev-parse --show-toplevel
- */
-
- DirectoryInfo? current = new FileInfo(Assembly.GetExecutingAssembly().Location).Directory;
- while (current != null && !current.EnumerateDirectories(".git").Any())
+ if (subFolders == null || subFolders.Length == 0)
{
- current = current.Parent;
+ return null;
}
- return current
- ?? throw new InvalidOperationException("Could not determine the root folder for this repository");
+ DirectoryInfo? start = new FileInfo(Assembly.GetExecutingAssembly().Location).Directory;
+ for (DirectoryInfo? current = start; current != null; current = current.Parent)
+ {
+ if (!current.Exists)
+ {
+ return null;
+ }
+ else if (subFolders.All(sub => current.EnumerateDirectories(sub).Any()))
+ {
+ return current;
+ }
+ }
+
+ return null;
}
private string GetRecordingFile()