diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index ff6c37f2d7b6..3e0495d5aaaf 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -49,6 +49,8 @@ class DefaultAzureCredential(ChainedTokenCredential): :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud. + :keyword bool exclude_workload_identity_credential: Whether to exclude the workload identity from the credential. + Defaults to **False**. :keyword bool exclude_azd_cli_credential: Whether to exclude the Azure Developer CLI from the credential. Defaults to **False**. :keyword bool exclude_cli_credential: Whether to exclude the Azure CLI from the credential. Defaults to **False**. @@ -120,6 +122,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement developer_credential_timeout = kwargs.pop("developer_credential_timeout", 10) + exclude_workload_identity_credential = kwargs.pop("exclude_workload_identity_credential", False) exclude_environment_credential = kwargs.pop("exclude_environment_credential", False) exclude_managed_identity_credential = kwargs.pop("exclude_managed_identity_credential", False) exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False) @@ -132,15 +135,22 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement credentials: List["TokenCredential"] = [] if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, **kwargs)) - if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS): - client_id = workload_identity_client_id - credentials.append(WorkloadIdentityCredential( - client_id=cast(str, client_id), - tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], - file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE], - **kwargs)) + if not exclude_workload_identity_credential: + if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS): + client_id = workload_identity_client_id + credentials.append(WorkloadIdentityCredential( + client_id=cast(str, client_id), + tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], + file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE], + **kwargs)) if not exclude_managed_identity_credential: - credentials.append(ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs)) + credentials.append( + ManagedIdentityCredential( + client_id=managed_identity_client_id, + _exclude_workload_identity_credential=exclude_workload_identity_credential, + **kwargs + ) + ) if not exclude_azd_cli_credential: credentials.append(AzureDeveloperCliCredential(process_timeout=developer_credential_timeout)) if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported(): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 4cc6c7b81c64..6391fa71779b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -35,6 +35,7 @@ class ManagedIdentityCredential: def __init__(self, **kwargs: Any) -> None: self._credential = None # type: Optional[TokenCredential] + exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False) if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): if os.environ.get(EnvironmentVariables.IDENTITY_HEADER): if os.environ.get(EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT): @@ -63,7 +64,8 @@ def __init__(self, **kwargs: Any) -> None: from .cloud_shell import CloudShellCredential self._credential = CloudShellCredential(**kwargs) - elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS): + elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS)\ + and not exclude_workload_identity: _LOGGER.info("%s will use workload identity", self.__class__.__name__) from .workload_identity import WorkloadIdentityCredential diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 6eae972eee4c..1026c93744bf 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -48,6 +48,8 @@ class DefaultAzureCredential(ChainedTokenCredential): :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud. + :keyword bool exclude_workload_identity_credential: Whether to exclude the workload identity from the credential. + Defaults to **False**. :keyword bool exclude_azd_cli_credential: Whether to exclude the Azure Developer CLI from the credential. Defaults to **False**. :keyword bool exclude_cli_credential: Whether to exclude the Azure CLI from the credential. Defaults to **False**. @@ -111,6 +113,7 @@ def __init__(self, **kwargs: Any) -> None: developer_credential_timeout = kwargs.pop("developer_credential_timeout", 10) + exclude_workload_identity_credential = kwargs.pop("exclude_workload_identity_credential", False) exclude_visual_studio_code_credential = kwargs.pop("exclude_visual_studio_code_credential", True) exclude_azd_cli_credential = kwargs.pop("exclude_azd_cli_credential", False) exclude_cli_credential = kwargs.pop("exclude_cli_credential", False) @@ -122,15 +125,22 @@ def __init__(self, **kwargs: Any) -> None: credentials = [] # type: List[AsyncTokenCredential] if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, **kwargs)) - if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS): - client_id = workload_identity_client_id - credentials.append(WorkloadIdentityCredential( - client_id=cast(str, client_id), - tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], - file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE], - **kwargs)) + if not exclude_workload_identity_credential: + if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS): + client_id = workload_identity_client_id + credentials.append(WorkloadIdentityCredential( + client_id=cast(str, client_id), + tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], + file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE], + **kwargs)) if not exclude_managed_identity_credential: - credentials.append(ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs)) + credentials.append( + ManagedIdentityCredential( + client_id=managed_identity_client_id, + _exclude_workload_identity_credential=exclude_workload_identity_credential, + **kwargs + ) + ) if not exclude_azd_cli_credential: credentials.append(AzureDeveloperCliCredential(process_timeout=developer_credential_timeout)) if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported(): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 5a520c71b14e..c0ff8b346480 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -36,6 +36,7 @@ class ManagedIdentityCredential(AsyncContextManager): def __init__(self, **kwargs: Any) -> None: self._credential = None # type: Optional[AsyncTokenCredential] + exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False) if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): if os.environ.get(EnvironmentVariables.IDENTITY_HEADER): @@ -70,7 +71,8 @@ def __init__(self, **kwargs: Any) -> None: from .cloud_shell import CloudShellCredential self._credential = CloudShellCredential(**kwargs) - elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS): + elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS) \ + and not exclude_workload_identity: _LOGGER.info("%s will use workload identity", self.__class__.__name__) from .workload_identity import WorkloadIdentityCredential diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 3441e37a324a..bdf46d487506 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -285,7 +285,7 @@ def test_default_credential_shared_cache_use(mock_credential): def test_managed_identity_client_id(): """the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable""" - expected_args = {"client_id": "the-client"} + expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False} with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential: DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"]) diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index e43c3100ff42..63abf636c215 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -245,7 +245,7 @@ async def test_default_credential_shared_cache_use(): def test_managed_identity_client_id(): """the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable""" - expected_args = {"client_id": "the client"} + expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False} with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential: DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])