diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index 4e226bc0c3577..c8603b662a6d0 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING import msal -from six.moves.urllib_parse import urlparse +import six from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -57,16 +57,27 @@ def _build_auth_record(response): # MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info home_account_id = id_token["sub"] + # "iss" is the URL of the issuing tenant e.g. https://authority/tenant + issuer = six.moves.urllib_parse.urlparse(id_token["iss"]) + + # tenant which issued the token, not necessarily user's home tenant + tenant_id = id_token.get("tid") or issuer.path.strip("/") + + # AAD returns "preferred_username", ADFS returns "upn" + username = id_token.get("preferred_username") or id_token["upn"] + return AuthenticationRecord( - authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant + authority=issuer.netloc, client_id=id_token["aud"], home_account_id=home_account_id, - tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant - username=id_token["preferred_username"], + tenant_id=tenant_id, + username=username, + ) + except (KeyError, ValueError) as ex: + auth_error = ClientAuthenticationError( + message="Failed to build AuthenticationRecord from unexpected identity token" ) - except (KeyError, ValueError): - # surprising: msal.ClientApplication always requests an id token, whose shape shouldn't change - return None + six.raise_from(auth_error, ex) class InteractiveCredential(MsalCredential): diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 692c686df2327..1bac8aaceeddb 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -14,35 +14,33 @@ import mock # type: ignore -# build_* lifted from msal tests def build_id_token( iss="issuer", sub="subject", - aud="my_client_id", + aud="client-id", username="username", - tenant_id="tenant id", - object_id="object id", - exp=None, - iat=None, + tenant_id="tenant-id", + object_id="object-id", **claims -): # AAD issues "preferred_username", ADFS issues "upn" - return "header.%s.signature" % base64.b64encode( - json.dumps( - dict( - { - "iss": iss, - "sub": sub, - "aud": aud, - "exp": exp or (time.time() + 100), - "iat": iat or time.time(), - "tid": tenant_id, - "oid": object_id, - "preferred_username": username, - }, - **claims - ) - ).encode() - ).decode("utf-8") +): + token_claims = id_token_claims( + iss=iss, sub=sub, aud=aud, tid=tenant_id, oid=object_id, preferred_username=username, **claims + ) + jwt_payload = base64.b64encode(json.dumps(token_claims).encode()).decode("utf-8") + return "header.{}.signature".format(jwt_payload) + + +def build_adfs_id_token(iss="issuer", sub="subject", aud="client-id", username="username", **claims): + token_claims = id_token_claims(iss=iss, sub=sub, aud=aud, upn=username, **claims) + jwt_payload = base64.b64encode(json.dumps(token_claims).encode()).decode("utf-8") + return "header.{}.signature".format(jwt_payload) + + +def id_token_claims(iss, sub, aud, exp=None, iat=None, **claims): + return dict( + {"iss": iss, "sub": sub, "aud": aud, "exp": exp or int(time.time()) + 3600, "iat": iat or int(time.time())}, + **claims + ) def build_aad_response( # simulate a response from AAD diff --git a/sdk/identity/azure-identity/tests/test_browser_credential.py b/sdk/identity/azure-identity/tests/test_browser_credential.py index e07c69f303964..25ea77f71b662 100644 --- a/sdk/identity/azure-identity/tests/test_browser_credential.py +++ b/sdk/identity/azure-identity/tests/test_browser_credential.py @@ -111,10 +111,13 @@ def test_disable_automatic_authentication(): @patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True) def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) - + client_id = "client-id" transport = validating_transport( requests=[Request()] * 2, - responses=[get_discovery_response(), mock_response(json_payload=build_aad_response(access_token="**"))], + responses=[ + get_discovery_response(), + mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))), + ], ) # mock local server fakes successful authentication by immediately returning a well-formed response @@ -123,7 +126,7 @@ def test_policies_configurable(): server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) credential = InteractiveBrowserCredential( - policies=[policy], transport=transport, server_class=server_class, _cache=TokenCache() + policies=[policy], client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache() ) with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): @@ -134,9 +137,13 @@ def test_policies_configurable(): @patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True) def test_user_agent(): + client_id = "client-id" transport = validating_transport( requests=[Request(), Request(required_headers={"User-Agent": USER_AGENT})], - responses=[get_discovery_response(), mock_response(json_payload=build_aad_response(access_token="**"))], + responses=[ + get_discovery_response(), + mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))), + ], ) # mock local server fakes successful authentication by immediately returning a well-formed response @@ -144,7 +151,9 @@ def test_user_agent(): auth_code_response = {"code": "authorization-code", "state": [oauth_state]} server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) - credential = InteractiveBrowserCredential(transport=transport, server_class=server_class, _cache=TokenCache()) + credential = InteractiveBrowserCredential( + client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache() + ) with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): credential.get_token("scope") @@ -284,7 +293,7 @@ def test_redirect_server(): thread.start() # send a request, verify the server exposes the query - url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec + url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec response = urllib.request.urlopen(url) # nosec assert response.code == 200 diff --git a/sdk/identity/azure-identity/tests/test_device_code_credential.py b/sdk/identity/azure-identity/tests/test_device_code_credential.py index f33553dcd6e3c..6428f7035345d 100644 --- a/sdk/identity/azure-identity/tests/test_device_code_credential.py +++ b/sdk/identity/azure-identity/tests/test_device_code_credential.py @@ -93,7 +93,9 @@ def test_disable_automatic_authentication(): empty_cache = TokenCache() # empty cache makes silent auth impossible transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) - credential = DeviceCodeCredential("client-id", disable_automatic_authentication=True, transport=transport, _cache=empty_cache) + credential = DeviceCodeCredential( + "client-id", disable_automatic_authentication=True, transport=transport, _cache=empty_cache + ) with pytest.raises(AuthenticationRequiredError): credential.get_token("scope") @@ -102,6 +104,7 @@ def test_disable_automatic_authentication(): def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) + client_id = "client-id" transport = validating_transport( requests=[Request()] * 3, responses=[ @@ -115,12 +118,16 @@ def test_policies_configurable(): "expires_in": 42, } ), - mock_response(json_payload=dict(build_aad_response(access_token="**"), scope="scope")), + mock_response( + json_payload=dict( + build_aad_response(access_token="**", id_token=build_id_token(aud=client_id)), scope="scope" + ) + ), ], ) credential = DeviceCodeCredential( - client_id="client-id", prompt_callback=Mock(), policies=[policy], transport=transport, _cache=TokenCache() + client_id=client_id, prompt_callback=Mock(), policies=[policy], transport=transport, _cache=TokenCache() ) credential.get_token("scope") @@ -129,6 +136,7 @@ def test_policies_configurable(): def test_user_agent(): + client_id = "client-id" transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], responses=[ @@ -141,18 +149,23 @@ def test_user_agent(): "expires_in": 42, } ), - mock_response(json_payload=dict(build_aad_response(access_token="**"), scope="scope")), + mock_response( + json_payload=dict( + build_aad_response(access_token="**", id_token=build_id_token(aud=client_id)), scope="scope" + ) + ), ], ) credential = DeviceCodeCredential( - client_id="client-id", prompt_callback=Mock(), transport=transport, _cache=TokenCache() + client_id=client_id, prompt_callback=Mock(), transport=transport, _cache=TokenCache() ) credential.get_token("scope") def test_device_code_credential(): + client_id = "client-id" expected_token = "access-token" user_code = "user-code" verification_uri = "verification-uri" @@ -172,20 +185,26 @@ def test_device_code_credential(): } ), mock_response( - json_payload={ - "access_token": expected_token, - "expires_in": expires_in, - "scope": "scope", - "token_type": "Bearer", - "refresh_token": "_", - } + json_payload=dict( + build_aad_response( + access_token=expected_token, + expires_in=expires_in, + refresh_token="_", + id_token=build_id_token(aud=client_id), + ), + scope="scope", + ), ), ], ) callback = Mock() credential = DeviceCodeCredential( - client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False, _cache=TokenCache() + client_id=client_id, + prompt_callback=callback, + transport=transport, + instance_discovery=False, + _cache=TokenCache(), ) now = datetime.datetime.utcnow() diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index 645e74f21bd0a..a708b7c2c9fce 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -18,7 +18,21 @@ except ImportError: # python < 3.3 from mock import Mock, patch # type: ignore -from helpers import build_aad_response +from helpers import build_aad_response, build_id_token, id_token_claims + + +# fake object for tests which need to exercise request_token but don't care about its return value +REQUEST_TOKEN_RESULT = build_aad_response( + access_token="***", + id_token_claims=id_token_claims( + aud="...", + iss="http://localhost/tenant", + sub="subject", + preferred_username="...", + tenant_id="...", + object_id="...", + ), +) class MockCredential(InteractiveCredential): @@ -132,7 +146,7 @@ def test_scopes_round_trip(): def validate_scopes(*scopes, **_): assert scopes == (scope,) - return {"access_token": "**", "expires_in": 42} + return REQUEST_TOKEN_RESULT request_token = Mock(wraps=validate_scopes) credential = MockCredential(disable_automatic_authentication=True, request_token=request_token) @@ -158,7 +172,7 @@ def test_authenticate_default_scopes(authority, expected_scope): def validate_scopes(*scopes): assert scopes == (expected_scope,) - return {"access_token": "**", "expires_in": 42} + return REQUEST_TOKEN_RESULT request_token = Mock(wraps=validate_scopes) MockCredential(authority=authority, request_token=request_token).authenticate() @@ -176,7 +190,7 @@ def test_authenticate_unknown_cloud(): def test_authenticate_ignores_disable_automatic_authentication(option): """authenticate should prompt for authentication regardless of the credential's configuration""" - request_token = Mock(return_value={"access_token": "**", "expires_in": 42}) + request_token = Mock(return_value=REQUEST_TOKEN_RESULT) MockCredential(request_token=request_token, disable_automatic_authentication=option).authenticate() assert request_token.call_count == 1, "credential didn't begin interactive authentication" @@ -296,19 +310,22 @@ def _request_token(self, *_, **__): assert record.home_account_id == "{}.{}".format(object_id, home_tenant) -def test_home_account_id_no_client_info(): - """the credential should use the subject claim as home_account_id when MSAL doesn't provide client_info""" +def test_adfs(): + """the credential should be able to construct an AuthenticationRecord from an ADFS response returned by MSAL""" + authority = "localhost" subject = "subject" + tenant = "adfs" + username = "username" msal_response = build_aad_response(access_token="***", refresh_token="**") - msal_response["id_token_claims"] = { - "aud": "client-id", - "iss": "https://localhost", - "object_id": "some-guid", - "tid": "some-tenant", - "preferred_username": "me", - "sub": subject, - } + msal_response["id_token_claims"] = id_token_claims( + aud="client-id", + iss="https://{}/{}".format(authority, tenant), + sub=subject, + tenant_id=tenant, + object_id="object-id", + upn=username, + ) class TestCredential(InteractiveCredential): def __init__(self, **kwargs): @@ -318,4 +335,7 @@ def _request_token(self, *_, **__): return msal_response record = TestCredential().authenticate() + assert record.authority == authority assert record.home_account_id == subject + assert record.tenant_id == tenant + assert record.username == username diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index efb5bcc668af5..5d756ecface17 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -769,7 +769,7 @@ def get_account_event( uid=uid, utid=utid, refresh_token=refresh_token, - id_token=build_id_token(aud=client_id, preferred_username=username), + id_token=build_id_token(aud=client_id, username=username), foci="1", **kwargs ), diff --git a/sdk/identity/azure-identity/tests/test_username_password_credential.py b/sdk/identity/azure-identity/tests/test_username_password_credential.py index f82d251090b02..5e4349a6e6dff 100644 --- a/sdk/identity/azure-identity/tests/test_username_password_credential.py +++ b/sdk/identity/azure-identity/tests/test_username_password_credential.py @@ -35,7 +35,8 @@ def test_policies_configurable(): transport = validating_transport( requests=[Request()] * 3, - responses=[get_discovery_response()] * 2 + [mock_response(json_payload=build_aad_response(access_token="**"))], + responses=[get_discovery_response()] * 2 + + [mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token()))], ) credential = UsernamePasswordCredential("client-id", "username", "password", policies=[policy], transport=transport) @@ -47,7 +48,8 @@ def test_policies_configurable(): def test_user_agent(): transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], - responses=[get_discovery_response()] * 2 + [mock_response(json_payload=build_aad_response(access_token="**"))], + responses=[get_discovery_response()] * 2 + + [mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token()))], ) credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport) @@ -57,6 +59,7 @@ def test_user_agent(): def test_username_password_credential(): expected_token = "access-token" + client_id = "client-id" transport = validating_transport( requests=[Request()] * 3, # not validating requests because they're formed by MSAL responses=[ @@ -66,18 +69,13 @@ def test_username_password_credential(): mock_response(json_payload={}), # token request mock_response( - json_payload={ - "access_token": expected_token, - "expires_in": 42, - "token_type": "Bearer", - "ext_expires_in": 42, - } + json_payload=build_aad_response(access_token=expected_token, id_token=build_id_token(aud=client_id)) ), ], ) credential = UsernamePasswordCredential( - client_id="some-guid", + client_id=client_id, username="user@azure", password="secret_password", transport=transport,