Skip to content

Commit

Permalink
Update sync credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp committed Jul 27, 2023
1 parent d99be44 commit 28f3fe7
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import logging
import os
from typing import Any
from typing import Any, Optional

from azure.core.credentials import AccessToken
from .chained import ChainedTokenCredential
Expand Down Expand Up @@ -63,24 +63,29 @@ def __init__(self, **kwargs: Any) -> None:
ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs),
)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str tenant_id: not used by this credential; any value provided will be ignored.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
token = self._successful_credential.get_token(*scopes, **kwargs)
token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info(
"%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__
)
return token

return super(AzureApplicationCredential, self).get_token(*scopes, **kwargs)
return super(AzureApplicationCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def close(self) -> None:
"""Close the credential's transport session."""
self.__exit__()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
Expand All @@ -74,6 +76,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -82,7 +85,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
``response`` attribute.
"""
# pylint:disable=useless-super-delegation
return super(AuthorizationCodeCredential, self).get_token(*scopes, **kwargs)
return super(AuthorizationCodeCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]:
return self._client.get_cached_access_token(scopes, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureDeveloperCliCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -101,6 +103,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -117,7 +120,11 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
commandString = " --scope ".join(scopes)
command = COMMAND_LINE.format(commandString)
tenant = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
claims=claims,
**kwargs,
)
if tenant:
command += " --tenant-id " + tenant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureCliCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -79,6 +81,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -91,7 +94,11 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
resource = _scopes_to_resource(*scopes)
command = COMMAND_LINE.format(resource)
tenant = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
claims=claims,
**kwargs,
)
if tenant:
command += " --tenant " + tenant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzurePowerShellCredential")
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand All @@ -93,6 +95,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str tenant_id: optional tenant to include in the token request.
:keyword str claims: not used by this credential; any value provided will be ignored.
:return: An access token with the desired scopes.
:rtype: ~azure.core.credentials.AccessToken
Expand All @@ -103,7 +106,11 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
receive an access token
"""
tenant_id = resolve_tenant(
default_tenant=self.tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
default_tenant=self.tenant_id,
tenant_id=tenant_id,
additionally_allowed_tenants=self._additionally_allowed_tenants,
claims=claims,
**kwargs,
)
command_line = get_command_line(scopes, tenant_id)
output = run_command_line(command_line, self._process_timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def __exit__(self, *args):
if self._client:
self._client.__exit__(*args)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
if not scopes:
raise ValueError("'get_token' requires at least one scope")

Expand All @@ -117,7 +119,9 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
token = self._client.obtain_token_by_refresh_token(
scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs
)
return token

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __enter__(self):
def __exit__(self, *args):
self._client.__exit__(*args)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
if not scopes:
raise ValueError('"get_token" requires at least one scope')

Expand All @@ -64,7 +66,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
raise CredentialUnavailableError(message="Shared token cache unavailable")
raise ClientAuthenticationError(message="Shared token cache unavailable")

return self._acquire_token_silent(*scopes, **kwargs)
return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _initialize(self):
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def __init__(
else:
super(InteractiveCredential, self).__init__(**kwargs)

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
"""Request an access token for `scopes`.
This method is called automatically by Azure SDK clients.
Expand All @@ -120,7 +122,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
For more information about scopes, see
https://learn.microsoft.com/azure/active-directory/develop/scopes-oidc.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
claims challenge following an authorization failure
:keyword str tenant_id: optional tenant to include in the token request.
:return: An access token with the desired scopes.
Expand All @@ -139,7 +141,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:

allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication)
try:
token = self._acquire_token_silent(*scopes, **kwargs)
token = self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
_LOGGER.info("%s.get_token succeeded", self.__class__.__name__)
return token
except Exception as ex: # pylint:disable=broad-except
Expand All @@ -156,7 +158,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
now = int(time.time())

try:
result = self._request_token(*scopes, **kwargs)
result = self._request_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
if "access_token" not in result:
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
response = self._client.get_error_response(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ def __exit__(self, *args):
def close(self) -> None:
self.__exit__()

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
def get_token(
self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
) -> AccessToken:
if not self._client:
raise CredentialUnavailableError(message=self.get_unavailable_message())
return super(ManagedIdentityBase, self).get_token(*scopes, **kwargs)
return super(ManagedIdentityBase, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)

def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]:
# casting because mypy can't determine that these methods are called
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def validate_scopes(*scopes, **_):
def test_authenticate_default_scopes(authority, expected_scope):
"""when given no scopes, authenticate should default to the ARM scope appropriate for the configured authority"""

def validate_scopes(*scopes):
def validate_scopes(*scopes, **_):
assert scopes == (expected_scope,)
return REQUEST_TOKEN_RESULT

Expand Down

0 comments on commit 28f3fe7

Please sign in to comment.