From d8f05b099831fa4d26d83ed66182bdc61475a6fc Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 21 May 2020 16:33:16 -0700 Subject: [PATCH 1/5] SharedTokenCacheCredential takes optional AuthenticationRecord --- .../identity/_credentials/shared_cache.py | 5 ++-- .../identity/_internal/shared_token_cache.py | 29 ++++++++++++++----- .../identity/aio/_credentials/shared_cache.py | 2 ++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index b748b700e5bf..8c56b8a71d99 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -14,8 +14,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Mapping - from azure.core.credentials import AccessToken + from typing import Any from .._internal import AadClientBase @@ -31,6 +30,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase): defines authorities for other clouds. :keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains tokens for multiple identities. + :keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as + :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential` """ @wrap_exceptions diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 6fba6c09b986..cbbeeb613bba 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -27,6 +27,7 @@ # pylint:disable=unused-import,ungrouped-imports from typing import Any, Iterable, List, Mapping, Optional from .._internal import AadClientBase + from azure.identity import AuthenticationRecord CacheItem = Mapping[str, str] @@ -86,13 +87,19 @@ class SharedTokenCacheBase(ABC): def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument # type: (Optional[str], **Any) -> None - authority = kwargs.pop("authority", None) - self._authority = normalize_authority(authority) if authority else get_default_authority() - - environment = urlparse(self._authority).netloc - self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) - self._username = username - self._tenant_id = kwargs.pop("tenant_id", None) + self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] + if self._auth_record: + self._authority = self._auth_record.authority + self._username = self._auth_record.username + self._tenant_id = self._auth_record.tenant_id + self._environment_aliases = frozenset((self._authority,)) + else: + authority = kwargs.pop("authority", None) + self._authority = normalize_authority(authority) if authority else get_default_authority() + environment = urlparse(self._authority).netloc + self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) + self._username = username + self._tenant_id = kwargs.pop("tenant_id", None) cache = kwargs.pop("_cache", None) # for ease of testing @@ -161,6 +168,14 @@ def _get_account(self, username=None, tenant_id=None): # cache is empty or contains no refresh token -> user needs to sign in raise CredentialUnavailableError(message=NO_ACCOUNTS) + if self._auth_record: + for account in accounts: + if account.get("home_account_id") == self._auth_record.home_account_id: + return account + raise CredentialUnavailableError( + message="The cache contains no account matching the given AuthenticationRecord." + ) + filtered_accounts = _filtered_accounts(accounts, username, tenant_id) if len(filtered_accounts) == 1: return filtered_accounts[0] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index ec021eba460d..dd32c9bebcc5 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -29,6 +29,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase): defines authorities for other clouds. :keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains tokens for multiple identities. + :keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as + :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential` """ async def __aenter__(self): From dab341706af719ac3259d2ccbe8d11b3ab01ac5a Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 21 May 2020 16:41:22 -0700 Subject: [PATCH 2/5] tests --- .../tests/test_shared_cache_credential.py | 98 ++++++++++++++++++- .../test_shared_cache_credential_async.py | 98 ++++++++++++++++++- 2 files changed, 194 insertions(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index daf5a73dd503..f563f69a21d5 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -4,7 +4,11 @@ # ------------------------------------ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential +from azure.identity import ( + AuthenticationRecord, + CredentialUnavailableError, + SharedTokenCacheCredential, +) from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( KNOWN_ALIASES, @@ -502,6 +506,98 @@ def test_authority_environment_variable(): assert token.token == expected_access_token +def test_authentication_record_empty_cache(): + record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) + + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") + + +def test_authentication_record_no_match(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + cache = populated_cache( + get_account_event( + "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, + ), + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") + + +def test_authentication_record(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache(account) + + transport = validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = credential.get_token("scope") + assert token.token == expected_access_token + + +def test_auth_record_multiple_accounts_for_username(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + expected_account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache( + expected_account, + get_account_event( # this account matches all but the record's tenant + username, + object_id, + "different-" + tenant_id, + authority=authority, + client_id=client_id, + refresh_token="not-" + expected_refresh_token, + ), + ) + + transport = validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = credential.get_token("scope") + assert token.token == expected_access_token + + def get_account_event( username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None ): diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index c718eaf3481c..2b0f42f53ec1 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -7,7 +7,7 @@ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import CredentialUnavailableError, KnownAuthorities +from azure.identity import AuthenticationRecord, CredentialUnavailableError from azure.identity.aio import SharedTokenCacheCredential from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( @@ -566,3 +566,99 @@ async def test_authority_environment_variable(): credential = SharedTokenCacheCredential(transport=transport, _cache=cache) token = await credential.get_token("scope") assert token.token == expected_access_token + + +@pytest.mark.asyncio +async def test_authentication_record_empty_cache(): + record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) + + with pytest.raises(CredentialUnavailableError): + await credential.get_token("scope") + + +@pytest.mark.asyncio +async def test_authentication_record_no_match(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + cache = populated_cache( + get_account_event( + "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, + ), + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + with pytest.raises(CredentialUnavailableError): + await credential.get_token("scope") + + +@pytest.mark.asyncio +async def test_authentication_record(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache(account) + + transport = async_validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = await credential.get_token("scope") + assert token.token == expected_access_token + + +@pytest.mark.asyncio +async def test_auth_record_multiple_accounts_for_username(): + tenant_id = "tenant-id" + client_id = "client-id" + authority = "localhost" + object_id = "object-id" + home_account_id = object_id + "." + tenant_id + username = "me" + record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) + + expected_access_token = "****" + expected_refresh_token = "**" + expected_account = get_account_event( + username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token + ) + cache = populated_cache( + expected_account, + get_account_event( # this account matches all but the record's tenant + username, + object_id, + "different-" + tenant_id, + authority=authority, + client_id=client_id, + refresh_token="not-" + expected_refresh_token, + ), + ) + + transport = async_validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + + token = await credential.get_token("scope") + assert token.token == expected_access_token From 6aca1709e43159866c9e2c3aa22ef8511501eeab Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 26 May 2020 14:56:03 -0700 Subject: [PATCH 3/5] update changelog --- sdk/identity/azure-identity/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 781fe4b105c3..6ce04863952d 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -17,6 +17,9 @@ the keyword argument `interactive_browser_tenant_id`, or set the environment variable `AZURE_TENANT_ID`. ([#11548](https://github.com/Azure/azure-sdk-for-python/issues/11548)) +- `SharedTokenCacheCredential` can be initialized with an `AuthenticationRecord` + provided by a user credential. + ([#11448](https://github.com/Azure/azure-sdk-for-python/issues/11448)) - The user authentication API added to `DeviceCodeCredential` and `InteractiveBrowserCredential` in 1.4.0b3 is available on `UsernamePasswordCredential` as well. From 5e7490483a79448c8f39dd7f46072ba945dbc8ef Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 28 May 2020 16:21:07 -0700 Subject: [PATCH 4/5] tenant_id kwarg specifies authenticating tenant --- .../azure/identity/_credentials/shared_cache.py | 2 +- .../azure/identity/_internal/shared_token_cache.py | 7 +++++-- .../identity/aio/_credentials/shared_cache.py | 2 +- .../tests/test_shared_cache_credential.py | 14 ++++++++++++++ .../tests/test_shared_cache_credential_async.py | 14 ++++++++++++++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 8c56b8a71d99..f48ba99bdd56 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -68,4 +68,4 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument def _get_auth_client(self, **kwargs): # type: (**Any) -> AadClientBase - return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs) + return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index cbbeeb613bba..5942850553d3 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -89,11 +89,14 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] if self._auth_record: + # authenticate in the tenant that produced the record unless 'tenant_id' specifies another + authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id + self._tenant_id = self._auth_record.tenant_id self._authority = self._auth_record.authority self._username = self._auth_record.username - self._tenant_id = self._auth_record.tenant_id self._environment_aliases = frozenset((self._authority,)) else: + authenticating_tenant = "common" authority = kwargs.pop("authority", None) self._authority = normalize_authority(authority) if authority else get_default_authority() environment = urlparse(self._authority).netloc @@ -117,7 +120,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument if cache: self._cache = cache self._client = self._get_auth_client( - authority=self._authority, cache=cache, **kwargs + authority=self._authority, tenant_id=authenticating_tenant, cache=cache, **kwargs ) # type: Optional[AadClientBase] else: self._client = None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index dd32c9bebcc5..c3ad6db261d9 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -76,4 +76,4 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) def _get_auth_client(self, **kwargs: "Any") -> "AadClientBase": - return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs) + return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs) diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index f563f69a21d5..9d8cba6ca757 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -598,6 +598,20 @@ def test_auth_record_multiple_accounts_for_username(): assert token.token == expected_access_token +def test_authentication_record_authenticating_tenant(): + """when given a record and 'tenant_id', the credential should authenticate in the latter""" + + expected_tenant_id = "tenant-id" + record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") + + with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: + SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) + + assert get_auth_client.call_count == 1 + _, kwargs = get_auth_client.call_args + assert kwargs["tenant_id"] == expected_tenant_id + + def get_account_event( username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None ): diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 2b0f42f53ec1..4552b4e8c9f3 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -662,3 +662,17 @@ async def test_auth_record_multiple_accounts_for_username(): token = await credential.get_token("scope") assert token.token == expected_access_token + + +def test_authentication_record_authenticating_tenant(): + """when given a record and 'tenant_id', the credential should authenticate in the latter""" + + expected_tenant_id = "tenant-id" + record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") + + with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: + SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) + + assert get_auth_client.call_count == 1 + _, kwargs = get_auth_client.call_args + assert kwargs["tenant_id"] == expected_tenant_id From d6458ff8f4641e90190a4a6be4a121852fe7a10e Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 4 Jun 2020 09:12:37 -0700 Subject: [PATCH 5/5] prefer organizations tenant --- .../azure/identity/_internal/shared_token_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 5942850553d3..045f3e637072 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -96,7 +96,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument self._username = self._auth_record.username self._environment_aliases = frozenset((self._authority,)) else: - authenticating_tenant = "common" + authenticating_tenant = "organizations" authority = kwargs.pop("authority", None) self._authority = normalize_authority(authority) if authority else get_default_authority() environment = urlparse(self._authority).netloc