Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Identity] Correctly implement TokenCredential protocols #31047

Merged
merged 21 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

### Bugs Fixed

- Credential types correctly implement `azure-core`'s `TokenCredential` protocol.
([#25175](https://github.com/Azure/azure-sdk-for-python/issues/25175))

### Other Changes

## 1.14.0b2 (2023-07-11)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import logging
import os
from typing import Any
from typing import Any, Optional

from azure.core.credentials import AccessToken
from .chained import ChainedTokenCredential
Expand Down Expand Up @@ -63,24 +63,29 @@ def __init__(self, **kwargs: Any) -> None:
ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs),
)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: not supported by this credential.
:keyword str tenant_id: not supported by this credential.

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
token = self._successful_credential.get_token(*scopes, **kwargs)
token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info(
"%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__
)
return token

return super(AzureApplicationCredential, self).get_token(*scopes, **kwargs)
return super(AzureApplicationCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def close(self) -> None:
"""Close the credential's transport session."""
self.__exit__()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.
Expand All @@ -74,6 +76,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not supported by this credential.
mccoyp marked this conversation as resolved.
Show resolved Hide resolved

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -82,7 +85,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
``response`` attribute.
"""
# pylint:disable=useless-super-delegation
return super(AuthorizationCodeCredential, self).get_token(*scopes, **kwargs)
return super(AuthorizationCodeCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]:
return self._client.get_cached_access_token(scopes, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureDeveloperCliCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs # pylint:disable=unused-argument
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -101,6 +103,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.
pvaneck marked this conversation as resolved.
Show resolved Hide resolved

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -117,7 +120,10 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
commandString = " --scope ".join(scopes)
command = COMMAND_LINE.format(commandString)
tenant = resolve_tenant(
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs,
)
if tenant:
command += " --tenant-id " + tenant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureCliCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs # pylint:disable=unused-argument
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -79,6 +81,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -91,7 +94,10 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
resource = _scopes_to_resource(*scopes)
command = COMMAND_LINE.format(resource)
tenant = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs,
)
if tenant:
command += " --tenant " + tenant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import subprocess
import sys
from typing import List, Tuple, Optional, Any
from typing import List, Tuple, Optional

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
Expand Down Expand Up @@ -83,7 +83,9 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzurePowerShellCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs # pylint:disable=unused-argument
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -93,6 +95,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -103,7 +106,10 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
receive an access token
"""
tenant_id = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
**kwargs,
)
command_line = get_command_line(scopes, tenant_id)
output = run_command_line(command_line, self._process_timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,19 @@ def close(self) -> None:
"""Close the transport session of each credential in the chain."""
self.__exit__()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Request a token from each chained credential, in order, returning the first token received.

This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -86,7 +91,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disab
history = []
for credential in self.credentials:
try:
token = credential.get_token(*scopes, **kwargs)
token = credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__)
self._successful_credential = credential
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import logging
import os
from typing import List, TYPE_CHECKING, Any, cast
from typing import List, TYPE_CHECKING, Any, Optional, cast

from azure.core.credentials import AccessToken
from .._constants import EnvironmentVariables
Expand Down Expand Up @@ -195,14 +195,18 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement

super(DefaultAzureCredential, self).__init__(*credentials)

def get_token(self, *scopes: str, **kwargs) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: optional tenant to include in the token request.

:return: An access token with the desired scopes.
Expand All @@ -212,12 +216,12 @@ def get_token(self, *scopes: str, **kwargs) -> AccessToken:
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
token = self._successful_credential.get_token(*scopes, **kwargs)
token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info(
"%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__
)
return token
within_dac.set(True)
token = super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
token = super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
within_dac.set(False)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def close(self) -> None:
self.__exit__()

@log_get_token("EnvironmentCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.
Expand All @@ -129,6 +131,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not supported by this credential.
mccoyp marked this conversation as resolved.
Show resolved Hide resolved

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -142,4 +145,4 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"this issue."
)
raise CredentialUnavailableError(message=message)
return self._credential.get_token(*scopes, **kwargs)
return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def close(self) -> None:
self.__exit__()

@log_get_token("ManagedIdentityCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.
Expand All @@ -117,6 +119,9 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.

:keyword str claims: not supported by this credential.
:keyword str tenant_id: not supported by this credential.

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.identity.CredentialUnavailableError: managed identity isn't available in the hosting environment
Expand All @@ -129,4 +134,4 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to "
"troubleshoot this issue."
)
return self._credential.get_token(*scopes, **kwargs)
return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def close(self) -> None:
self.__exit__()

@log_get_token("SharedTokenCacheCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
"""Get an access token for `scopes` from the shared cache.

If no access token is cached, attempt to acquire one using a cached refresh token.
Expand All @@ -64,16 +66,18 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:keyword str tenant_id: not supported by this credential.
:keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested
token. Defaults to False.

:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason.
"""
return self._credential.get_token(*scopes, **kwargs)
return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

@staticmethod
def supported() -> bool:
Expand All @@ -97,7 +101,9 @@ def __exit__(self, *args):
if self._client:
self._client.__exit__(*args)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
if not scopes:
raise ValueError("'get_token' requires at least one scope")

Expand All @@ -123,7 +129,9 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
token = self._client.obtain_token_by_refresh_token(
scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs
)
return token

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def __enter__(self):
def __exit__(self, *args):
self._client.__exit__(*args)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
) -> AccessToken:
if not scopes:
raise ValueError('"get_token" requires at least one scope')

Expand All @@ -70,7 +72,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
raise CredentialUnavailableError(message="Shared token cache unavailable")
raise ClientAuthenticationError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)
return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]:

Expand Down
Loading