From fc1a231f5e2d63297994cee57ef172f962a027a0 Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Fri, 22 Nov 2024 14:23:30 -0600 Subject: [PATCH] `ManagedIdentityCredential` honors CancellationTokens (#47171) --- sdk/identity/Azure.Identity/CHANGELOG.md | 6 +- .../src/MsalManagedIdentityClient.cs | 4 +- .../tests/ImdsManagedIdentitySourceTests.cs | 59 +++++++++++++++++++ 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/sdk/identity/Azure.Identity/CHANGELOG.md b/sdk/identity/Azure.Identity/CHANGELOG.md index c8ab91606a862..5231e0ab93c53 100644 --- a/sdk/identity/Azure.Identity/CHANGELOG.md +++ b/sdk/identity/Azure.Identity/CHANGELOG.md @@ -7,8 +7,10 @@ ### Breaking Changes ### Bugs Fixed - - Fixed an issue where setting `DefaultAzureCredentialOptions.TenantId` twice throws an `InvalidOperationException`. ([#47035](https://github.com/Azure/azure-sdk-for-net/issues/47035)) - - Fixed an issue where some credentials in `DefaultAzureCredential` would not fall through to the next credential in the chain under certain exception conditions. + +- Fixed an issue where setting `DefaultAzureCredentialOptions.TenantId` twice throws an `InvalidOperationException` ([#47035](https://github.com/Azure/azure-sdk-for-net/issues/47035)) +- Fixed an issue where `ManagedIdentityCredential` does not honor the `CancellationToken` passed to `GetToken` and `GetTokenAsync`. ([#47156](https://github.com/Azure/azure-sdk-for-net/issues/47156)) +- Fixed an issue where some credentials in `DefaultAzureCredential` would not fall through to the next credential in the chain under certain exception conditions. ### Other Changes diff --git a/sdk/identity/Azure.Identity/src/MsalManagedIdentityClient.cs b/sdk/identity/Azure.Identity/src/MsalManagedIdentityClient.cs index 4ccb6383b5c18..750431c5067d3 100644 --- a/sdk/identity/Azure.Identity/src/MsalManagedIdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/MsalManagedIdentityClient.cs @@ -99,8 +99,8 @@ public virtual async ValueTask AcquireTokenForManagedIdent } #pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult(). return async ? - await builder.ExecuteAsync().ConfigureAwait(false) : - builder.ExecuteAsync().GetAwaiter().GetResult(); + await builder.ExecuteAsync(cancellationToken).ConfigureAwait(false) : + builder.ExecuteAsync(cancellationToken).GetAwaiter().GetResult(); #pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult(). } } diff --git a/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs b/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs index 8882324a81b46..a869edf8058eb 100644 --- a/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs +++ b/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading; using System.Threading.Tasks; using Azure.Core; using Azure.Core.Pipeline; @@ -174,6 +175,64 @@ public void ManagedIdentityCredentialUsesDefaultTimeoutAndRetries() CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts); } + [Test] + public void ManagedIdentityCredentialRetryBehaviorIsOverriddenWithOptions() + { + int callCount = 0; + List networkTimeouts = new(); + + var mockTransport = MockTransport.FromMessageCallback(msg => + { + callCount++; + networkTimeouts.Add(msg.NetworkTimeout); + Assert.IsTrue(msg.Request.Headers.TryGetValue(ImdsManagedIdentitySource.metadataHeaderName, out _)); + return CreateMockResponse(500, "Error").WithHeader("Content-Type", "application/json"); + }); + + var options = new TokenCredentialOptions() + { + Transport = mockTransport, + RetryPolicy = new RetryPolicy(1, DelayStrategy.CreateFixedDelayStrategy(TimeSpan.Zero)) + }; + options.Retry.MaxDelay = TimeSpan.Zero; + + var cred = new ManagedIdentityCredential( + "testCLientId", options); + + Assert.ThrowsAsync(async () => await cred.GetTokenAsync(new(new[] { "test" }))); + + var expectedTimeouts = new TimeSpan?[] { null, null }; + CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts); + } + + [Test] + public void ManagedIdentityCredentialRespectsCancellationToken() + { + int callCount = 0; + + var mockTransport = MockTransport.FromMessageCallback(msg => + { + Task.Delay(1000).GetAwaiter().GetResult(); + callCount++; + return CreateMockResponse(500, "Error").WithHeader("Content-Type", "application/json"); + }); + + var options = new TokenCredentialOptions() { Transport = mockTransport }; + options.Retry.MaxDelay = TimeSpan.FromSeconds(1); + + var cred = new ManagedIdentityCredential( + "testCLientId", options); + + var cts = new CancellationTokenSource(); + cts.CancelAfter(TimeSpan.Zero); + var ex = Assert.CatchAsync(async () => await cred.GetTokenAsync(new(new[] { "test" }), cts.Token)); + Assert.IsTrue(ex is TaskCanceledException || ex is OperationCanceledException, "Expected TaskCanceledException or OperationCanceledException but got " + ex.GetType().ToString()); + + // Default number of retries is 5, so we should just ensure we have less than that. + // Timing on some platforms makes this test somewhat non-deterministic, so we just ensure we have less than 2 calls. + Assert.Less(callCount, 2); + } + private MockResponse CreateMockResponse(int responseCode, string token) { var response = new MockResponse(responseCode);