Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly pass options to ManagedIdentityCredential #34122

Merged
merged 4 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions sdk/identity/Azure.Identity/src/CredentialPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,13 @@ internal class CredentialPipeline

private CredentialPipeline(TokenCredentialOptions options)
{
AuthorityHost = options.AuthorityHost;

HttpPipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), Array.Empty<HttpPipelinePolicy>(), new CredentialResponseClassifier());

Diagnostics = new ClientDiagnostics(options);
}

public CredentialPipeline(Uri authorityHost, HttpPipeline httpPipeline, ClientDiagnostics diagnostics)
public CredentialPipeline(HttpPipeline httpPipeline, ClientDiagnostics diagnostics)
{
AuthorityHost = authorityHost;

HttpPipeline = httpPipeline;

Diagnostics = diagnostics;
}

Expand All @@ -38,8 +32,6 @@ public static CredentialPipeline GetInstance(TokenCredentialOptions options)
return options is null ? s_singleton.Value : new CredentialPipeline(options);
}

public Uri AuthorityHost { get; }

public HttpPipeline HttpPipeline { get; }

public ClientDiagnostics Diagnostics { get; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;
using System;
using System.ComponentModel;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
Expand All @@ -25,6 +25,7 @@ public class DeviceCodeCredential : TokenCredential
internal AuthenticationRecord Record { get; private set; }
internal Func<DeviceCodeInfo, CancellationToken, Task> DeviceCodeCallback { get; }
internal CredentialPipeline Pipeline { get; }
internal string DefaultScope { get; }

private const string AuthenticationRequiredMessage = "Interactive authentication is needed to acquire token. Call Authenticate to initiate the device code authentication.";
private const string NoDefaultScopeMessage = "Authenticating in this environment requires specifying a TokenRequestContext.";
Expand Down Expand Up @@ -66,8 +67,8 @@ public DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> device
/// <param name="clientId">The client id of the application to which the users will authenticate</param>
/// <param name="options">The client options for the newly created DeviceCodeCredential</param>
[EditorBrowsable(EditorBrowsableState.Never)]
public DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> deviceCodeCallback, string tenantId, string clientId, TokenCredentialOptions options = default)
: this(deviceCodeCallback, Validations.ValidateTenantId(tenantId, nameof(tenantId), allowNull:true), clientId, options, null)
public DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> deviceCodeCallback, string tenantId, string clientId, TokenCredentialOptions options = default)
: this(deviceCodeCallback, Validations.ValidateTenantId(tenantId, nameof(tenantId), allowNull: true), clientId, options, null)
{
}

Expand All @@ -80,18 +81,18 @@ internal DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> devi
{
Argument.AssertNotNull(clientId, nameof(clientId));
Argument.AssertNotNull(deviceCodeCallback, nameof(deviceCodeCallback));

_tenantId = tenantId;
ClientId = clientId;
DeviceCodeCallback = deviceCodeCallback;
DisableAutomaticAuthentication = (options as DeviceCodeCredentialOptions)?.DisableAutomaticAuthentication ?? false;
Record = (options as DeviceCodeCredentialOptions)?.AuthenticationRecord;
Pipeline = pipeline ?? CredentialPipeline.GetInstance(options);
DefaultScope = AzureAuthorityHosts.GetDefaultScope(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault());
Client = client ?? new MsalPublicClient(
Pipeline,
tenantId,
ClientId,
AzureAuthorityHosts.GetDeviceCodeRedirectUri(Pipeline.AuthorityHost).AbsoluteUri,
AzureAuthorityHosts.GetDeviceCodeRedirectUri(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault()).AbsoluteUri,
options);
AdditionallyAllowedTenantIds = TenantIdResolver.ResolveAddionallyAllowedTenantIds(options?.AdditionallyAllowedTenantsCore);
}
Expand All @@ -103,10 +104,13 @@ internal DeviceCodeCredential(Func<DeviceCodeInfo, CancellationToken, Task> devi
/// <returns>The result of the authentication request, containing the acquired <see cref="AccessToken"/>, and the <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account.</returns>
public virtual AuthenticationRecord Authenticate(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return Authenticate(new TokenRequestContext(new string[] { defaultScope }), cancellationToken);
return Authenticate(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken);
}

/// <summary>
Expand All @@ -116,10 +120,13 @@ public virtual AuthenticationRecord Authenticate(CancellationToken cancellationT
/// <returns>The <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account on future execution of credentials using the same persisted token cache.</returns>
public virtual async Task<AuthenticationRecord> AuthenticateAsync(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return await AuthenticateAsync(new TokenRequestContext(new string[] { defaultScope }), cancellationToken).ConfigureAwait(false);
return await AuthenticateAsync(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;
using System;
using System.ComponentModel;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Microsoft.Identity.Client;

namespace Azure.Identity
{
Expand All @@ -25,6 +25,7 @@ public class InteractiveBrowserCredential : TokenCredential
internal CredentialPipeline Pipeline { get; }
internal bool DisableAutomaticAuthentication { get; }
internal AuthenticationRecord Record { get; private set; }
internal string DefaultScope { get; }

private const string AuthenticationRequiredMessage = "Interactive authentication is needed to acquire token. Call Authenticate to interactively authenticate.";
private const string NoDefaultScopeMessage = "Authenticating in this environment requires specifying a TokenRequestContext.";
Expand Down Expand Up @@ -83,6 +84,7 @@ internal InteractiveBrowserCredential(string tenantId, string clientId, TokenCre
Pipeline = pipeline ?? CredentialPipeline.GetInstance(options);
LoginHint = (options as InteractiveBrowserCredentialOptions)?.LoginHint;
var redirectUrl = (options as InteractiveBrowserCredentialOptions)?.RedirectUri?.AbsoluteUri ?? Constants.DefaultRedirectUrl;
DefaultScope = AzureAuthorityHosts.GetDefaultScope(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault());
Client = client ?? new MsalPublicClient(Pipeline, tenantId, clientId, redirectUrl, options);
AdditionallyAllowedTenantIds = TenantIdResolver.ResolveAddionallyAllowedTenantIds(options?.AdditionallyAllowedTenantsCore);
Record = (options as InteractiveBrowserCredentialOptions)?.AuthenticationRecord;
Expand All @@ -95,10 +97,12 @@ internal InteractiveBrowserCredential(string tenantId, string clientId, TokenCre
/// <returns>The result of the authentication request, containing the acquired <see cref="AccessToken"/>, and the <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account.</returns>
public virtual AuthenticationRecord Authenticate(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);

return Authenticate(new TokenRequestContext(new[] { defaultScope }), cancellationToken);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}
return Authenticate(new TokenRequestContext(new[] { DefaultScope }), cancellationToken);
}

/// <summary>
Expand All @@ -108,10 +112,12 @@ public virtual AuthenticationRecord Authenticate(CancellationToken cancellationT
/// <returns>The result of the authentication request, containing the acquired <see cref="AccessToken"/>, and the <see cref="AuthenticationRecord"/> which can be used to silently authenticate the account.</returns>
public virtual async Task<AuthenticationRecord> AuthenticateAsync(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(Pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);

return await AuthenticateAsync(new TokenRequestContext(new string[] { defaultScope }), cancellationToken).ConfigureAwait(false);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}
return await AuthenticateAsync(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ protected ManagedIdentityCredential()
/// </param>
/// <param name="options">Options to configure the management of the requests sent to the Azure Active Directory service.</param>
public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions options = null)
: this(clientId, CredentialPipeline.GetInstance(options))
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ClientId = clientId, Pipeline = CredentialPipeline.GetInstance(options), Options = options }))
{
_logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false;
}
Expand All @@ -59,21 +59,20 @@ public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions
/// </param>
/// <param name="options">Options to configure the management of the requests sent to the Azure Active Directory service.</param>
public ManagedIdentityCredential(ResourceIdentifier resourceId, TokenCredentialOptions options = null)
: this(
new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options) }))
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options), Options = options }))
{
_logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false;
_clientId = resourceId.ToString();
}

internal ManagedIdentityCredential(string clientId, CredentialPipeline pipeline, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { Pipeline = pipeline, ClientId = clientId, PreserveTransport = preserveTransport }))
internal ManagedIdentityCredential(string clientId, CredentialPipeline pipeline, TokenCredentialOptions options = null, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions { Pipeline = pipeline, ClientId = clientId, PreserveTransport = preserveTransport, Options = options }))
{
_clientId = clientId;
}

internal ManagedIdentityCredential(ResourceIdentifier resourceId, CredentialPipeline pipeline, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions{Pipeline = pipeline, ResourceIdentifier = resourceId, PreserveTransport = preserveTransport}))
internal ManagedIdentityCredential(ResourceIdentifier resourceId, CredentialPipeline pipeline, TokenCredentialOptions options, bool preserveTransport = false)
: this(new ManagedIdentityClient(new ManagedIdentityClientOptions{Pipeline = pipeline, ResourceIdentifier = resourceId, PreserveTransport = preserveTransport, Options = options }))
{
_clientId = resourceId.ToString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class UsernamePasswordCredential : TokenCredential
private readonly string _tenantId;
internal string[] AdditionallyAllowedTenantIds { get; }
internal MsalPublicClient Client { get; }
internal string DefaultScope { get; }

/// <summary>
/// Protected constructor for mocking
Expand Down Expand Up @@ -92,6 +93,7 @@ internal UsernamePasswordCredential(
_password = password;
_clientId = clientId;
_pipeline = pipeline ?? CredentialPipeline.GetInstance(options);
DefaultScope = AzureAuthorityHosts.GetDefaultScope(options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault());
Client = client ?? new MsalPublicClient(_pipeline, tenantId, clientId, null, options);

AdditionallyAllowedTenantIds = TenantIdResolver.ResolveAddionallyAllowedTenantIds(options?.AdditionallyAllowedTenantsCore);
Expand All @@ -104,10 +106,13 @@ internal UsernamePasswordCredential(
/// <returns>The <see cref="AuthenticationRecord"/> of the authenticated account.</returns>
public virtual AuthenticationRecord Authenticate(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(_pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return Authenticate(new TokenRequestContext(new string[] { defaultScope }), cancellationToken);
return Authenticate(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken);
}

/// <summary>
Expand All @@ -117,10 +122,13 @@ public virtual AuthenticationRecord Authenticate(CancellationToken cancellationT
/// <returns>The <see cref="AuthenticationRecord"/> of the authenticated account.</returns>
public virtual async Task<AuthenticationRecord> AuthenticateAsync(CancellationToken cancellationToken = default)
{
// get the default scope for the authority, throw if no default scope exists
string defaultScope = AzureAuthorityHosts.GetDefaultScope(_pipeline.AuthorityHost) ?? throw new CredentialUnavailableException(NoDefaultScopeMessage);
// throw if no default scope exists
if (DefaultScope == null)
{
throw new CredentialUnavailableException(NoDefaultScopeMessage);
}

return await AuthenticateAsync(new TokenRequestContext(new string[] { defaultScope }), cancellationToken).ConfigureAwait(false);
return await AuthenticateAsync(new TokenRequestContext(new string[] { DefaultScope }), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand Down
5 changes: 4 additions & 1 deletion sdk/identity/Azure.Identity/src/MsalClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ protected MsalClientBase(CredentialPipeline pipeline, string tenantId, string cl
// variable rather than in code. In this case we need to validate the endpoint before we use it. However, we can't validate in
// CredentialPipeline as this is also used by the ManagedIdentityCredential which allows non TLS endpoints. For this reason
// we validate here as all other credentials will create an MSAL client.
Validations.ValidateAuthorityHost(pipeline?.AuthorityHost);
Validations.ValidateAuthorityHost(options?.AuthorityHost);
AuthorityHost = options?.AuthorityHost ?? AzureAuthorityHosts.GetDefault();
_logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false;
DisableInstanceDiscovery = options is ISupportsDisableInstanceDiscovery supportsDisableInstanceDiscovery && supportsDisableInstanceDiscovery.DisableInstanceDiscovery;
ITokenCacheOptions cacheOptions = options as ITokenCacheOptions;
Expand All @@ -50,6 +51,8 @@ protected MsalClientBase(CredentialPipeline pipeline, string tenantId, string cl

internal TokenCache TokenCache { get; }

internal Uri AuthorityHost { get; }

protected internal CredentialPipeline Pipeline { get; }

protected abstract ValueTask<TClient> CreateClientAsync(bool async, CancellationToken cancellationToken);
Expand Down
Loading