Skip to content

Commit

Permalink
exclude workload identity from dac (#29728)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangyan99 authored Apr 3, 2023
1 parent 8c171cd commit b0f7788
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 20 deletions.
26 changes: 18 additions & 8 deletions sdk/identity/azure-identity/azure/identity/_credentials/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class DefaultAzureCredential(ChainedTokenCredential):
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud.
:keyword bool exclude_workload_identity_credential: Whether to exclude the workload identity from the credential.
Defaults to **False**.
:keyword bool exclude_azd_cli_credential: Whether to exclude the Azure Developer CLI
from the credential. Defaults to **False**.
:keyword bool exclude_cli_credential: Whether to exclude the Azure CLI from the credential. Defaults to **False**.
Expand Down Expand Up @@ -120,6 +122,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement

developer_credential_timeout = kwargs.pop("developer_credential_timeout", 10)

exclude_workload_identity_credential = kwargs.pop("exclude_workload_identity_credential", False)
exclude_environment_credential = kwargs.pop("exclude_environment_credential", False)
exclude_managed_identity_credential = kwargs.pop("exclude_managed_identity_credential", False)
exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False)
Expand All @@ -132,15 +135,22 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
credentials: List["TokenCredential"] = []
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, **kwargs))
if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
client_id = workload_identity_client_id
credentials.append(WorkloadIdentityCredential(
client_id=cast(str, client_id),
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs))
if not exclude_workload_identity_credential:
if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
client_id = workload_identity_client_id
credentials.append(WorkloadIdentityCredential(
client_id=cast(str, client_id),
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs))
if not exclude_managed_identity_credential:
credentials.append(ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs))
credentials.append(
ManagedIdentityCredential(
client_id=managed_identity_client_id,
_exclude_workload_identity_credential=exclude_workload_identity_credential,
**kwargs
)
)
if not exclude_azd_cli_credential:
credentials.append(AzureDeveloperCliCredential(process_timeout=developer_credential_timeout))
if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ManagedIdentityCredential:

def __init__(self, **kwargs: Any) -> None:
self._credential = None # type: Optional[TokenCredential]
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
if os.environ.get(EnvironmentVariables.IDENTITY_HEADER):
if os.environ.get(EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT):
Expand Down Expand Up @@ -63,7 +64,8 @@ def __init__(self, **kwargs: Any) -> None:
from .cloud_shell import CloudShellCredential

self._credential = CloudShellCredential(**kwargs)
elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS)\
and not exclude_workload_identity:
_LOGGER.info("%s will use workload identity", self.__class__.__name__)
from .workload_identity import WorkloadIdentityCredential

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class DefaultAzureCredential(ChainedTokenCredential):
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
defines authorities for other clouds. Managed identities ignore this because they reside in a single cloud.
:keyword bool exclude_workload_identity_credential: Whether to exclude the workload identity from the credential.
Defaults to **False**.
:keyword bool exclude_azd_cli_credential: Whether to exclude the Azure Developer CLI
from the credential. Defaults to **False**.
:keyword bool exclude_cli_credential: Whether to exclude the Azure CLI from the credential. Defaults to **False**.
Expand Down Expand Up @@ -111,6 +113,7 @@ def __init__(self, **kwargs: Any) -> None:

developer_credential_timeout = kwargs.pop("developer_credential_timeout", 10)

exclude_workload_identity_credential = kwargs.pop("exclude_workload_identity_credential", False)
exclude_visual_studio_code_credential = kwargs.pop("exclude_visual_studio_code_credential", True)
exclude_azd_cli_credential = kwargs.pop("exclude_azd_cli_credential", False)
exclude_cli_credential = kwargs.pop("exclude_cli_credential", False)
Expand All @@ -122,15 +125,22 @@ def __init__(self, **kwargs: Any) -> None:
credentials = [] # type: List[AsyncTokenCredential]
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, **kwargs))
if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
client_id = workload_identity_client_id
credentials.append(WorkloadIdentityCredential(
client_id=cast(str, client_id),
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs))
if not exclude_workload_identity_credential:
if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
client_id = workload_identity_client_id
credentials.append(WorkloadIdentityCredential(
client_id=cast(str, client_id),
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs))
if not exclude_managed_identity_credential:
credentials.append(ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs))
credentials.append(
ManagedIdentityCredential(
client_id=managed_identity_client_id,
_exclude_workload_identity_credential=exclude_workload_identity_credential,
**kwargs
)
)
if not exclude_azd_cli_credential:
credentials.append(AzureDeveloperCliCredential(process_timeout=developer_credential_timeout))
if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ManagedIdentityCredential(AsyncContextManager):

def __init__(self, **kwargs: Any) -> None:
self._credential = None # type: Optional[AsyncTokenCredential]
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)

if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
if os.environ.get(EnvironmentVariables.IDENTITY_HEADER):
Expand Down Expand Up @@ -70,7 +71,8 @@ def __init__(self, **kwargs: Any) -> None:
from .cloud_shell import CloudShellCredential

self._credential = CloudShellCredential(**kwargs)
elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
elif all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS) \
and not exclude_workload_identity:
_LOGGER.info("%s will use workload identity", self.__class__.__name__)
from .workload_identity import WorkloadIdentityCredential

Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_default_credential_shared_cache_use(mock_credential):
def test_managed_identity_client_id():
"""the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""

expected_args = {"client_id": "the-client"}
expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False}

with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential:
DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/tests/test_default_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def test_default_credential_shared_cache_use():
def test_managed_identity_client_id():
"""the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""

expected_args = {"client_id": "the client"}
expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False}

with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential:
DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])
Expand Down

0 comments on commit b0f7788

Please sign in to comment.