Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #14974: [BUG] DefaultAzureCredential improperly catches AuthenticationFailedException #15057

Merged
merged 4 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions sdk/identity/Azure.Identity/src/AuthenticationFailedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,5 @@ public AuthenticationFailedException(string message, Exception innerException)
: base(message, innerException)
{
}

internal static AuthenticationFailedException CreateAggregateException(string message, IList<Exception> exceptions)
{
// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(message);

bool allCredentialUnavailableException = true;
foreach (var exception in exceptions)
{
allCredentialUnavailableException &= exception is CredentialUnavailableException;
errorMsg.Append(Environment.NewLine).Append("- ").Append(exception.Message);
}

var innerException = exceptions.Count == 1
? exceptions[0]
: new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", exceptions);

// If all credentials have thrown CredentialUnavailableException, throw CredentialUnavailableException,
// otherwise throw AuthenticationFailedException
return allCredentialUnavailableException
? new CredentialUnavailableException(errorMsg.ToString(), innerException)
: new AuthenticationFailedException(errorMsg.ToString(), innerException);
}
}
}
13 changes: 6 additions & 7 deletions sdk/identity/Azure.Identity/src/ChainedTokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class ChainedTokenCredential : TokenCredential
{
private const string AggregateAllUnavailableErrorMessage = "The ChainedTokenCredential failed to retrieve a token from the included credentials.";

private const string AggregateCredentialFailedErrorMessage = "The ChainedTokenCredential failed due to an unhandled exception: ";
private const string AuthenticationFailedErrorMessage = "The ChainedTokenCredential failed due to an unhandled exception: ";

private readonly TokenCredential[] _sources;

Expand Down Expand Up @@ -77,7 +77,7 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
var groupScopeHandler = new ScopeGroupHandler(default);
try
{
List<Exception> exceptions = new List<Exception>();
List<CredentialUnavailableException> exceptions = new List<CredentialUnavailableException>();
foreach (TokenCredential source in _sources)
{
try
Expand All @@ -88,18 +88,17 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
groupScopeHandler.Dispose(default, default);
return token;
}
catch (AuthenticationFailedException e)
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
catch (Exception e)
AlexanderSher marked this conversation as resolved.
Show resolved Hide resolved
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(AggregateCredentialFailedErrorMessage + e.Message, exceptions);
throw new AuthenticationFailedException(AuthenticationFailedErrorMessage + e.Message, e);
}
}

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
throw CredentialUnavailableException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}
catch (Exception exception)
{
Expand Down
21 changes: 21 additions & 0 deletions sdk/identity/Azure.Identity/src/CredentialUnavailableException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Text;
using Azure.Core;

namespace Azure.Identity
Expand Down Expand Up @@ -31,5 +33,24 @@ public CredentialUnavailableException(string message, Exception innerException)
: base(message, innerException)
{
}

internal static CredentialUnavailableException CreateAggregateException(string message, IList<CredentialUnavailableException> exceptions)
{
if (exceptions.Count == 1)
{
return exceptions[0];
}

// Build the credential unavailable message, this code is only reachable if all credentials throw AuthenticationFailedException
StringBuilder errorMsg = new StringBuilder(message);

foreach (var exception in exceptions)
{
errorMsg.Append(Environment.NewLine).Append("- ").Append(exception.Message);
}

var innerException = new AggregateException("Multiple exceptions were encountered while attempting to authenticate.", exceptions);
return new CredentialUnavailableException(errorMsg.ToString(), innerException);
}
}
}
11 changes: 3 additions & 8 deletions sdk/identity/Azure.Identity/src/DefaultAzureCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ private static async ValueTask<AccessToken> GetTokenFromCredentialAsync(TokenCre

private static async ValueTask<(AccessToken, TokenCredential)> GetTokenFromSourcesAsync(TokenCredential[] sources, TokenRequestContext requestContext, bool async, CancellationToken cancellationToken)
{
List<Exception> exceptions = new List<Exception>();
List<CredentialUnavailableException> exceptions = new List<CredentialUnavailableException>();

for (var i = 0; i < sources.Length && sources[i] != null; i++)
{
Expand All @@ -155,18 +155,13 @@ private static async ValueTask<AccessToken> GetTokenFromCredentialAsync(TokenCre

return (token, sources[i]);
}
catch (AuthenticationFailedException e)
catch (CredentialUnavailableException e)
{
exceptions.Add(e);
}
catch (Exception e) when (!(e is OperationCanceledException))
{
exceptions.Add(e);
throw AuthenticationFailedException.CreateAggregateException(UnhandledExceptionMessage + e.Message, exceptions);
}
schaabs marked this conversation as resolved.
Show resolved Hide resolved
}

throw AuthenticationFailedException.CreateAggregateException(DefaultExceptionMessage, exceptions);
throw CredentialUnavailableException.CreateAggregateException(DefaultExceptionMessage, exceptions);
}

private static TokenCredential[] GetDefaultAzureCredentialChain(DefaultAzureCredentialFactory factory, DefaultAzureCredentialOptions options)
Expand Down
25 changes: 19 additions & 6 deletions sdk/identity/Azure.Identity/src/VisualStudioCodeCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@ private async ValueTask<AccessToken> GetTokenImplAsync(TokenRequestContext reque
}

var cloudInstance = GetAzureCloudInstance(environmentName);
var storedCredentials = _vscAdapter.GetCredentials(CredentialsSection, environmentName);

if (!IsRefreshTokenString(storedCredentials))
{
throw new CredentialUnavailableException("Need to re-authenticate user in VSCode Azure Account.");
}
string storedCredentials = GetStoredCredentials(environmentName);

var result = await _client.AcquireTokenByRefreshToken(requestContext.Scopes, storedCredentials, cloudInstance, tenant, async, cancellationToken).ConfigureAwait(false);
return scope.Succeeded(new AccessToken(result.AccessToken, result.ExpiresOn));
Expand All @@ -89,6 +84,24 @@ private async ValueTask<AccessToken> GetTokenImplAsync(TokenRequestContext reque
}
}

private string GetStoredCredentials(string environmentName)
{
try
{
var storedCredentials = _vscAdapter.GetCredentials(CredentialsSection, environmentName);
if (!IsRefreshTokenString(storedCredentials))
{
throw new CredentialUnavailableException("Need to re-authenticate user in VSCode Azure Account.");
}

return storedCredentials;
}
catch (InvalidOperationException ex)
{
throw new CredentialUnavailableException("Stored credentials not found. Need to authenticate user in VSCode Azure Account.", ex);
}
}

private static bool IsRefreshTokenString(string str)
{
for (var index = 0; index < str.Length; index++)
Expand Down
41 changes: 26 additions & 15 deletions sdk/identity/Azure.Identity/src/VisualStudioCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ private async Task<AccessToken> RunProcessesAsync(List<ProcessStartInfo> process
{
exceptions.Add(new CredentialUnavailableException($"Process \"{processStartInfo.FileName}\" has non-json output: {output}.", exception));
}
catch (Exception exception)
catch (Exception exception) when (!(exception is OperationCanceledException))
{
exceptions.Add(exception);
exceptions.Add(new CredentialUnavailableException($"Process \"{processStartInfo.FileName}\" has failed with unexpected error: {exception.Message}.", exception));
}
}

Expand Down Expand Up @@ -192,24 +192,35 @@ private VisualStudioTokenProvider[] GetTokenProviders(string tokenProviderPath)
{
var content = GetTokenProviderContent(tokenProviderPath);

using JsonDocument document = JsonDocument.Parse(content);
try
{
using JsonDocument document = JsonDocument.Parse(content);

JsonElement providersElement = document.RootElement.GetProperty("TokenProviders");
JsonElement providersElement = document.RootElement.GetProperty("TokenProviders");

var providers = new VisualStudioTokenProvider[providersElement.GetArrayLength()];
for (int i = 0; i < providers.Length; i++)
{
JsonElement providerElement = providersElement[i];
var providers = new VisualStudioTokenProvider[providersElement.GetArrayLength()];
for (int i = 0; i < providers.Length; i++)
{
JsonElement providerElement = providersElement[i];

var path = providerElement.GetProperty("Path").GetString();
var preference = providerElement.GetProperty("Preference").GetInt32();
var arguments = GetStringArrayPropertyValue(providerElement, "Arguments");
var path = providerElement.GetProperty("Path").GetString();
var preference = providerElement.GetProperty("Preference").GetInt32();
var arguments = GetStringArrayPropertyValue(providerElement, "Arguments");

providers[i] = new VisualStudioTokenProvider(path, arguments, preference);
}
providers[i] = new VisualStudioTokenProvider(path, arguments, preference);
}

Array.Sort(providers);
return providers;
Array.Sort(providers);
return providers;
}
catch (JsonException exception)
{
throw new CredentialUnavailableException($"File found at \"{tokenProviderPath}\" isn't a valid JSON file", exception);
}
catch (Exception exception)
{
throw new CredentialUnavailableException($"JSON file found at \"{tokenProviderPath}\" has invalid schema.", exception);
}
}

private string GetTokenProviderContent(string tokenProviderPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void AuthenticateWithCliCredential_InvalidJsonOutput([Values("", "{}", "{
{
var testProcess = new TestProcess { Output = jsonContent };
AzureCliCredential credential = InstrumentClient(new AzureCliCredential(CredentialPipeline.GetInstance(null), new TestProcessService(testProcess)));
Assert.CatchAsync<AuthenticationFailedException>(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)));
Assert.ThrowsAsync<AuthenticationFailedException>(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default)));
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public async Task ChainedTokenCredential_UseVisualStudioCredential()
}

[Test]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore2_keyring" })]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore_keyring" })]
public async Task ChainedTokenCredential_UseVisualStudioCodeCredential()
{
var cloudName = Guid.NewGuid().ToString();
Expand Down Expand Up @@ -89,7 +89,7 @@ public async Task ChainedTokenCredential_UseVisualStudioCodeCredential()
}

[Test]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore2_keyring" })]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore_keyring" })]
public async Task ChainedTokenCredential_UseVisualStudioCodeCredential_ParallelCalls()
{
var cloudName = Guid.NewGuid().ToString();
Expand Down Expand Up @@ -209,7 +209,35 @@ public void ChainedTokenCredential_AllCredentialsHaveFailed_CredentialUnavailabl
}

[Test]
public void ChainedTokenCredential_AllCredentialsHaveFailed_AuthenticationFailedException()
[NonParallelizable]
public void ChainedTokenCredential_AllCredentialsHaveFailed_FirstAuthenticationFailedException()
{
using var endpoint = new TestEnvVar("MSI_ENDPOINT", "abc");

var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", null);
var fileSystem = new TestFileSystemService();
var processService = new TestProcessService(new TestProcess {Error = "Error"});

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId);
var vsCredential = new VisualStudioCredential(default, default, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, default, default, fileSystem, vscAdapter);
var azureCliCredential = new AzureCliCredential(CredentialPipeline.GetInstance(null), processService);

var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential, vscCredential, azureCliCredential));

List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;
using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
Assert.CatchAsync<AuthenticationFailedException>(async () => await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None));
scopes = diagnosticListener.Scopes;
}

Assert.AreEqual(1, scopes.Count);
Assert.AreEqual($"{nameof(ManagedIdentityCredential)}.{nameof(ManagedIdentityCredential.GetToken)}", scopes[0].Name);
}

[Test]
public void ChainedTokenCredential_AllCredentialsHaveFailed_LastAuthenticationFailedException()
{
var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", null);
var fileSystem = new TestFileSystemService();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public async Task DefaultAzureCredential_UseVisualStudioCredential()
}

[Test]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore2_keyring" })]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore_keyring" })]
public async Task DefaultAzureCredential_UseVisualStudioCodeCredential()
{
var options = Recording.InstrumentClientOptions(new DefaultAzureCredentialOptions
Expand Down Expand Up @@ -101,7 +101,7 @@ public async Task DefaultAzureCredential_UseVisualStudioCodeCredential()
}

[Test]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore2_keyring" })]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore_keyring" })]
public async Task DefaultAzureCredential_UseVisualStudioCodeCredential_ParallelCalls()
{
var options = Recording.InstrumentClientOptions(new DefaultAzureCredentialOptions
Expand Down Expand Up @@ -237,7 +237,37 @@ public void DefaultAzureCredential_AllCredentialsHaveFailed_CredentialUnavailabl
}

[Test]
public void DefaultAzureCredential_AllCredentialsHaveFailed_AuthenticationFailedException()
[NonParallelizable]
public void DefaultAzureCredential_AllCredentialsHaveFailed_FirstAuthenticationFailedException()
{
using var endpoint = new TestEnvVar("MSI_ENDPOINT", "abc");

var options = Recording.InstrumentClientOptions(new DefaultAzureCredentialOptions
{
ExcludeEnvironmentCredential = true,
ExcludeInteractiveBrowserCredential = true,
ExcludeSharedTokenCacheCredential = true,
});

var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", null);
var factory = new TestDefaultAzureCredentialFactory(options, new TestFileSystemService(), new TestProcessService(new TestProcess { Error = "Error" }), vscAdapter);
var credential = InstrumentClient(new DefaultAzureCredential(factory, options));

List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;

using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
Assert.CatchAsync<AuthenticationFailedException>(async () => await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None));
scopes = diagnosticListener.Scopes;
}

Assert.AreEqual(2, scopes.Count);
Assert.AreEqual($"{nameof(DefaultAzureCredential)}.{nameof(DefaultAzureCredential.GetToken)}", scopes[0].Name);
Assert.AreEqual($"{nameof(ManagedIdentityCredential)}.{nameof(ManagedIdentityCredential.GetToken)}", scopes[1].Name);
}

[Test]
public void DefaultAzureCredential_AllCredentialsHaveFailed_LastAuthenticationFailedException()
{
var options = Recording.InstrumentClientOptions(new DefaultAzureCredentialOptions
{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading