Skip to content

Commit

Permalink
Properly pass options to ManagedIdentityCredential (#34122)
Browse files Browse the repository at this point in the history
* Properly pass options to ManagedIdentityCredential
  • Loading branch information
christothes authored Feb 15, 2023
1 parent ce8c527 commit 0b7acde
Show file tree
Hide file tree
Showing 19 changed files with 101 additions and 99 deletions.
10 changes: 1 addition & 9 deletions sdk/identity/Azure.Identity/src/CredentialPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,13 @@ internal class CredentialPipeline

private CredentialPipeline(TokenCredentialOptions options)
{
AuthorityHost = options.AuthorityHost;

HttpPipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), Array.Empty<HttpPipelinePolicy>(), new CredentialResponseClassifier());

Diagnostics = new ClientDiagnostics(options);
}

public CredentialPipeline(Uri authorityHost, HttpPipeline httpPipeline, ClientDiagnostics diagnostics)
public CredentialPipeline(HttpPipeline httpPipeline, ClientDiagnostics diagnostics)
{
AuthorityHost = authorityHost;

HttpPipeline = httpPipeline;

Diagnostics = diagnostics;
}

Expand All @@ -38,8 +32,6 @@ public static CredentialPipeline GetInstance(TokenCredentialOptions options)
return options is null ? s_singleton.Value : new CredentialPipeline(options);
}

public Uri AuthorityHost { get; }

public HttpPipeline HttpPipeline { get; }

public ClientDiagnostics Diagnostics { get; }
Expand Down
33 changes: 20 additions & 13 deletions sdk/identity/Azure.Identity/src/Credentials/DeviceCodeCredential.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;
using System;
using System.ComponentModel;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
Expand All @@ -25,6 +25,7 @@ public class DeviceCodeCredential : TokenCredential
internal AuthenticationRecord Record { get; private set; }
internal Func<DeviceCodeInfo, CancellationToken, Task> DeviceCodeCallback { get; }
internal CredentialPipeline Pipeline { get; }
internal string DefaultScope { get; }

private const string AuthenticationRequiredMessage = "Interactive authentication is needed to acquire token. Call Authenticate to initiate the device code authentication.";
private const string NoDefaultScopeMessage = "Authenticating in this environment requires specifying a TokenRequestContext.";
Expand Down Expand Up @@ -66,8 +67,8 @@ public DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> device
/// <param name="clientId">The client id of the application to which the users will authenticate</param>
/// <param name="options">The client options for the newly created DeviceCodeCredential</param>
[EditorBrowsable(EditorBrowsableState.Never)]
public DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> deviceCodeCallback, string tenantId, string clientId, TokenCredentialOptions options = default)
: this(deviceCodeCallback, Validations.ValidateTenantId(tenantId, nameof(tenantId), allowNull:true), clientId, options, null)
public DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> deviceCodeCallback, string tenantId, string clientId, TokenCredentialOptions options = default)
: this(deviceCodeCallback, Validations.ValidateTenantId(tenantId, nameof(tenantId), allowNull: true), clientId, options, null)
{
}

Expand All @@ -80,18 +81,18 @@ internal DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> devi
{
Argument.AssertNotNull(clientId, nameof(clientId));
Argument.AssertNotNull(deviceCodeCallback, nameof(deviceCodeCallback));

_tenantId = tenantId;
ClientId = clientId;
DeviceCodeCallback = deviceCodeCallback;
DisableAutomaticAuthentication = (options as DeviceCodeCredentialOptions)?.DisableAutomaticAuthentication ?? false;
Record = (options as DeviceCodeCredentialOptions)?.AuthenticationRecord;
Pipeline = pipeline ?? CredentialPipeline.GetInstance(options);
DefaultScope = AzureAuthorityHosts.GetDefaultScope(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault());
Client = client ?? new MsalPublicClient(
Pipeline,
tenantId,
ClientId,
AzureAuthorityHosts.GetDeviceCodeRedirectUri(Pipeline.AuthorityHost).AbsoluteUri,
AzureAuthorityHosts.GetDeviceCodeRedirectUri(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault()).AbsoluteUri,
options);
AdditionallyAllowedTenantIds = TenantIdResolver.ResolveAddionallyAllowedTenantIds(options?.AdditionallyAllowedTenantsCore);
}
Expand All @@ -103,10 +104,13 @@ internal DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> devi
/// <returns>The result of the authentication request, containing the acquired <see cref="AccessToken"/>, and the <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account.</returns>
public virtual AuthenticationRecord Authenticate(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return Authenticate(new TokenRequestContext(new string[] { defaultScope }), cancellationToken);
return Authenticate(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken);
}

/// <summary>
Expand All @@ -116,10 +120,13 @@ public virtual AuthenticationRecord Authenticate(CancellationToken cancellationT
/// <returns>The <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account on future execution of credentials using the same persisted token cache.</returns>
public virtual async Task<AuthenticationRecord> AuthenticateAsync(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return await AuthenticateAsync(new TokenRequestContext(new string[] { defaultScope }), cancellationToken).ConfigureAwait(false);
return await AuthenticateAsync(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;
using System;
using System.ComponentModel;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
Expand All @@ -25,6 +25,7 @@ public class InteractiveBrowserCredential : TokenCredential
internal CredentialPipeline Pipeline { get; }
internal bool DisableAutomaticAuthentication { get; }
internal AuthenticationRecord Record { get; private set; }
internal string DefaultScope { get; }

private const string AuthenticationRequiredMessage = "Interactive authentication is needed to acquire token. Call Authenticate to interactively authenticate.";
private const string NoDefaultScopeMessage = "Authenticating in this environment requires specifying a TokenRequestContext.";
Expand Down Expand Up @@ -83,6 +84,7 @@ internal InteractiveBrowserCredential(string tenantId, string clientId, TokenCre
Pipeline = pipeline ?? CredentialPipeline.GetInstance(options);
LoginHint = (options as InteractiveBrowserCredentialOptions)?.LoginHint;
var redirectUrl = (options as InteractiveBrowserCredentialOptions)?.RedirectUri?.AbsoluteUri ?? Constants.DefaultRedirectUrl;
DefaultScope = AzureAuthorityHosts.GetDefaultScope(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault());
Client = client ?? new MsalPublicClient(Pipeline, tenantId, clientId, redirectUrl, options);
AdditionallyAllowedTenantIds = TenantIdResolver.ResolveAddionallyAllowedTenantIds(options?.AdditionallyAllowedTenantsCore);
Record = (options as InteractiveBrowserCredentialOptions)?.AuthenticationRecord;
Expand All @@ -95,10 +97,12 @@ internal InteractiveBrowserCredential(string tenantId, string clientId, TokenCre
/// <returns>The result of the authentication request, containing the acquired <see cref="AccessToken"/>, and the <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account.</returns>
public virtual AuthenticationRecord Authenticate(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);

return Authenticate(new TokenRequestContext(new[] { defaultScope }), cancellationToken);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}
return Authenticate(new TokenRequestContext(new[] { DefaultScope }), cancellationToken);
}

/// <summary>
Expand All @@ -108,10 +112,12 @@ public virtual AuthenticationRecord Authenticate(CancellationToken cancellationT
/// <returns>The result of the authentication request, containing the acquired <see cref="AccessToken"/>, and the <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account.</returns>
public virtual async Task<AuthenticationRecord> AuthenticateAsync(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);

return await AuthenticateAsync(new TokenRequestContext(new string[] { defaultScope }), cancellationToken).ConfigureAwait(false);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}
return await AuthenticateAsync(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ protected ManagedIdentityCredential()
/// </param>
/// <param name="options">Options to configure the management of the requests sent to the Azure Active Directory service.</param>
public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions options = null)
: this(clientId, CredentialPipeline.GetInstance(options))
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ClientId = clientId, Pipeline = CredentialPipeline.GetInstance(options), Options = options }))
{
_logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false;
}
Expand All @@ -59,21 +59,20 @@ public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions
/// </param>
/// <param name="options">Options to configure the management of the requests sent to the Azure Active Directory service.</param>
public ManagedIdentityCredential(ResourceIdentifier resourceId, TokenCredentialOptions options = null)
: this(
new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options) }))
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options), Options = options }))
{
_logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false;
_clientId = resourceId.ToString();
}

internal ManagedIdentityCredential(string clientId, CredentialPipeline pipeline, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { Pipeline = pipeline, ClientId = clientId, PreserveTransport = preserveTransport }))
internal ManagedIdentityCredential(string clientId, CredentialPipeline pipeline, TokenCredentialOptions options = null, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { Pipeline = pipeline, ClientId = clientId, PreserveTransport = preserveTransport, Options = options }))
{
_clientId = clientId;
}

internal ManagedIdentityCredential(ResourceIdentifier resourceId, CredentialPipeline pipeline, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions{Pipeline = pipeline, ResourceIdentifier = resourceId, PreserveTransport = preserveTransport}))
internal ManagedIdentityCredential(ResourceIdentifier resourceId, CredentialPipeline pipeline, TokenCredentialOptions options, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions{Pipeline = pipeline, ResourceIdentifier = resourceId, PreserveTransport = preserveTransport, Options = options }))
{
_clientId = resourceId.ToString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class UsernamePasswordCredential : TokenCredential
private readonly string _tenantId;
internal string[] AdditionallyAllowedTenantIds { get; }
internal MsalPublicClient Client { get; }
internal string DefaultScope { get; }

/// <summary>
/// Protected constructor for mocking
Expand Down Expand Up @@ -92,6 +93,7 @@ internal UsernamePasswordCredential(
_password = password;
_clientId = clientId;
_pipeline = pipeline ?? CredentialPipeline.GetInstance(options);
DefaultScope = AzureAuthorityHosts.GetDefaultScope(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault());
Client = client ?? new MsalPublicClient(_pipeline, tenantId, clientId, null, options);

AdditionallyAllowedTenantIds = TenantIdResolver.ResolveAddionallyAllowedTenantIds(options?.AdditionallyAllowedTenantsCore);
Expand All @@ -104,10 +106,13 @@ internal UsernamePasswordCredential(
/// <returns>The <see cref="AuthenticationRecord"/> of the authenticated account.</returns>
public virtual AuthenticationRecord Authenticate(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(_pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return Authenticate(new TokenRequestContext(new string[] { defaultScope }), cancellationToken);
return Authenticate(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken);
}

/// <summary>
Expand All @@ -117,10 +122,13 @@ public virtual AuthenticationRecord Authenticate(CancellationToken cancellationT
/// <returns>The <see cref="AuthenticationRecord"/> of the authenticated account.</returns>
public virtual async Task<AuthenticationRecord> AuthenticateAsync(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(_pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return await AuthenticateAsync(new TokenRequestContext(new string[] { defaultScope }), cancellationToken).ConfigureAwait(false);
return await AuthenticateAsync(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
5 changes: 4 additions & 1 deletion sdk/identity/Azure.Identity/src/MsalClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ protected MsalClientBase(CredentialPipeline pipeline, string tenantId, string cl
// variable rather than in code. In this case we need to validate the endpoint before we use it. However, we can't validate in
// CredentialPipeline as this is also used by the ManagedIdentityCredential which allows non TLS endpoints. For this reason
// we validate here as all other credentials will create an MSAL client.
Validations.ValidateAuthorityHost(pipeline?.AuthorityHost);
Validations.ValidateAuthorityHost(options?.AuthorityHost);
AuthorityHost = options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault();
_logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false;
DisableInstanceDiscovery = options is ISupportsDisableInstanceDiscovery supportsDisableInstanceDiscovery && supportsDisableInstanceDiscovery.DisableInstanceDiscovery;
ITokenCacheOptions cacheOptions = options as ITokenCacheOptions;
Expand All @@ -50,6 +51,8 @@ protected MsalClientBase(CredentialPipeline pipeline, string tenantId, string cl

internal TokenCache TokenCache { get; }

internal Uri AuthorityHost { get; }

protected internal CredentialPipeline Pipeline { get; }

protected abstract ValueTask<TClient> CreateClientAsync(bool async, CancellationToken cancellationToken);
Expand Down
Loading

0 comments on commit 0b7acde

Please sign in to comment.