Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SharedTokenCacheCredential takes an optional AuthenticationRecord #11637

Merged
merged 5 commits into from
Jun 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
the keyword argument `interactive_browser_tenant_id`, or set the environment
variable `AZURE_TENANT_ID`.
([#11548](https://github.com/Azure/azure-sdk-for-python/issues/11548))
- `SharedTokenCacheCredential` can be initialized with an `AuthenticationRecord`
provided by a user credential.
([#11448](https://github.com/Azure/azure-sdk-for-python/issues/11448))
- The user authentication API added to `DeviceCodeCredential` and
`InteractiveBrowserCredential` in 1.4.0b3 is available on
`UsernamePasswordCredential` as well.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Mapping
from azure.core.credentials import AccessToken
from typing import Any
from .._internal import AadClientBase


Expand All @@ -31,6 +30,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
defines authorities for other clouds.
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
tokens for multiple identities.
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
"""

@wrap_exceptions
Expand Down Expand Up @@ -67,4 +68,4 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs)
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional
from .._internal import AadClientBase
from azure.identity import AuthenticationRecord
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not from ..__auth_record?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a public class, and I don't want to depend on its internal location.


CacheItem = Mapping[str, str]

Expand Down Expand Up @@ -86,13 +87,22 @@ class SharedTokenCacheBase(ABC):
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()

environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)
self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless 'tenant_id' specifies another
authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
self._tenant_id = self._auth_record.tenant_id
self._authority = self._auth_record.authority
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see we have normalize_authority(authority) in ctor of AuthenticationRecord. So maybe we need it here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accounts in MSAL's cache don't have "normalized" authorities, so neither does AuthenticationRecord. This credential's client will normalize the authority it's given.

self._username = self._auth_record.username
self._environment_aliases = frozenset((self._authority,))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect w/ or w/o auth_record, we should have same behavior here as

environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good eye. We don't need to parse the netloc out of the record's authority because the record's authority is already a netloc. As for why I'm ignoring potential aliases: the purpose of the record is to enable an application to use tokens it cached in prior authentications. An AuthenticationRecord and the account it represents will have identical values for authority because they're both results of the same authentication. An account having an alias of the record's authority was cached during a different authentication, and therefore doesn't match the record.

else:
authenticating_tenant = "organizations"
authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)

cache = kwargs.pop("_cache", None) # for ease of testing

Expand All @@ -110,7 +120,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
if cache:
self._cache = cache
self._client = self._get_auth_client(
authority=self._authority, cache=cache, **kwargs
authority=self._authority, tenant_id=authenticating_tenant, cache=cache, **kwargs
) # type: Optional[AadClientBase]
else:
self._client = None
Expand Down Expand Up @@ -161,6 +171,14 @@ def _get_account(self, username=None, tenant_id=None):
# cache is empty or contains no refresh token -> user needs to sign in
raise CredentialUnavailableError(message=NO_ACCOUNTS)

if self._auth_record:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need try catch here? if it is an invalid auth_record, what's our expected behavior?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean if it is invalid, we want to raise ValueError or CredentialUnavailableError or ignore it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean if self._auth_record.home_account_id raises AttributeError? In that case self._auth_record isn't an AuthenticationRecord. Handling that exception seems tantamount to an isinstance check.

for account in accounts:
if account.get("home_account_id") == self._auth_record.home_account_id:
return account
raise CredentialUnavailableError(
message="The cache contains no account matching the given AuthenticationRecord."
)

filtered_accounts = _filtered_accounts(accounts, username, tenant_id)
if len(filtered_accounts) == 1:
return filtered_accounts[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase):
defines authorities for other clouds.
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
tokens for multiple identities.
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
"""

async def __aenter__(self):
Expand Down Expand Up @@ -74,4 +76,4 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))

def _get_auth_client(self, **kwargs: "Any") -> "AadClientBase":
return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs)
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)
112 changes: 111 additions & 1 deletion sdk/identity/azure-identity/tests/test_shared_cache_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
# ------------------------------------
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential
from azure.identity import (
AuthenticationRecord,
CredentialUnavailableError,
SharedTokenCacheCredential,
)
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
KNOWN_ALIASES,
Expand Down Expand Up @@ -502,6 +506,112 @@ def test_authority_environment_variable():
assert token.token == expected_access_token


def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_authentication_record_no_match():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_authentication_record():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(account)

transport = validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = credential.get_token("scope")
assert token.token == expected_access_token


def test_auth_record_multiple_accounts_for_username():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
expected_account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(
expected_account,
get_account_event( # this account matches all but the record's tenant
username,
object_id,
"different-" + tenant_id,
authority=authority,
client_id=client_id,
refresh_token="not-" + expected_refresh_token,
),
)

transport = validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = credential.get_token("scope")
assert token.token == expected_access_token


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
assert kwargs["tenant_id"] == expected_tenant_id


def get_account_event(
username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.identity import CredentialUnavailableError, KnownAuthorities
from azure.identity import AuthenticationRecord, CredentialUnavailableError
from azure.identity.aio import SharedTokenCacheCredential
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.shared_token_cache import (
Expand Down Expand Up @@ -566,3 +566,113 @@ async def test_authority_environment_variable():
credential = SharedTokenCacheCredential(transport=transport, _cache=cache)
token = await credential.get_token("scope")
assert token.token == expected_access_token


@pytest.mark.asyncio
async def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")


@pytest.mark.asyncio
async def test_authentication_record_no_match():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")


@pytest.mark.asyncio
async def test_authentication_record():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(account)

transport = async_validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = await credential.get_token("scope")
assert token.token == expected_access_token


@pytest.mark.asyncio
async def test_auth_record_multiple_accounts_for_username():
tenant_id = "tenant-id"
client_id = "client-id"
authority = "localhost"
object_id = "object-id"
home_account_id = object_id + "." + tenant_id
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

expected_access_token = "****"
expected_refresh_token = "**"
expected_account = get_account_event(
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
)
cache = populated_cache(
expected_account,
get_account_event( # this account matches all but the record's tenant
username,
object_id,
"different-" + tenant_id,
authority=authority,
client_id=client_id,
refresh_token="not-" + expected_refresh_token,
),
)

transport = async_validating_transport(
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)

token = await credential.get_token("scope")
assert token.token == expected_access_token


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
assert kwargs["tenant_id"] == expected_tenant_id