Skip to content

Commit

Permalink
Build AuthenticationRecords from ADFS identity tokens (Azure#13341)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored and rakshith91 committed Sep 4, 2020
1 parent 535a8b2 commit b12d9fd
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import TYPE_CHECKING

import msal
from six.moves.urllib_parse import urlparse
import six
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

Expand Down Expand Up @@ -57,16 +57,27 @@ def _build_auth_record(response):
# MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info
home_account_id = id_token["sub"]

# "iss" is the URL of the issuing tenant e.g. https://authority/tenant
issuer = six.moves.urllib_parse.urlparse(id_token["iss"])

# tenant which issued the token, not necessarily user's home tenant
tenant_id = id_token.get("tid") or issuer.path.strip("/")

# AAD returns "preferred_username", ADFS returns "upn"
username = id_token.get("preferred_username") or id_token["upn"]

return AuthenticationRecord(
authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant
authority=issuer.netloc,
client_id=id_token["aud"],
home_account_id=home_account_id,
tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant
username=id_token["preferred_username"],
tenant_id=tenant_id,
username=username,
)
except (KeyError, ValueError) as ex:
auth_error = ClientAuthenticationError(
message="Failed to build AuthenticationRecord from unexpected identity token"
)
except (KeyError, ValueError):
# surprising: msal.ClientApplication always requests an id token, whose shape shouldn't change
return None
six.raise_from(auth_error, ex)


class InteractiveCredential(MsalCredential):
Expand Down
46 changes: 22 additions & 24 deletions sdk/identity/azure-identity/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,33 @@
import mock # type: ignore


# build_* lifted from msal tests
def build_id_token(
iss="issuer",
sub="subject",
aud="my_client_id",
aud="client-id",
username="username",
tenant_id="tenant id",
object_id="object id",
exp=None,
iat=None,
tenant_id="tenant-id",
object_id="object-id",
**claims
): # AAD issues "preferred_username", ADFS issues "upn"
return "header.%s.signature" % base64.b64encode(
json.dumps(
dict(
{
"iss": iss,
"sub": sub,
"aud": aud,
"exp": exp or (time.time() + 100),
"iat": iat or time.time(),
"tid": tenant_id,
"oid": object_id,
"preferred_username": username,
},
**claims
)
).encode()
).decode("utf-8")
):
token_claims = id_token_claims(
iss=iss, sub=sub, aud=aud, tid=tenant_id, oid=object_id, preferred_username=username, **claims
)
jwt_payload = base64.b64encode(json.dumps(token_claims).encode()).decode("utf-8")
return "header.{}.signature".format(jwt_payload)


def build_adfs_id_token(iss="issuer", sub="subject", aud="client-id", username="username", **claims):
token_claims = id_token_claims(iss=iss, sub=sub, aud=aud, upn=username, **claims)
jwt_payload = base64.b64encode(json.dumps(token_claims).encode()).decode("utf-8")
return "header.{}.signature".format(jwt_payload)


def id_token_claims(iss, sub, aud, exp=None, iat=None, **claims):
return dict(
{"iss": iss, "sub": sub, "aud": aud, "exp": exp or int(time.time()) + 3600, "iat": iat or int(time.time())},
**claims
)


def build_aad_response( # simulate a response from AAD
Expand Down
21 changes: 15 additions & 6 deletions sdk/identity/azure-identity/tests/test_browser_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,13 @@ def test_disable_automatic_authentication():
@patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True)
def test_policies_configurable():
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

client_id = "client-id"
transport = validating_transport(
requests=[Request()] * 2,
responses=[get_discovery_response(), mock_response(json_payload=build_aad_response(access_token="**"))],
responses=[
get_discovery_response(),
mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))),
],
)

# mock local server fakes successful authentication by immediately returning a well-formed response
Expand All @@ -123,7 +126,7 @@ def test_policies_configurable():
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
policies=[policy], transport=transport, server_class=server_class, _cache=TokenCache()
policies=[policy], client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache()
)

with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
Expand All @@ -134,17 +137,23 @@ def test_policies_configurable():

@patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True)
def test_user_agent():
client_id = "client-id"
transport = validating_transport(
requests=[Request(), Request(required_headers={"User-Agent": USER_AGENT})],
responses=[get_discovery_response(), mock_response(json_payload=build_aad_response(access_token="**"))],
responses=[
get_discovery_response(),
mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))),
],
)

# mock local server fakes successful authentication by immediately returning a well-formed response
oauth_state = "oauth-state"
auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(transport=transport, server_class=server_class, _cache=TokenCache())
credential = InteractiveBrowserCredential(
client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache()
)

with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
credential.get_token("scope")
Expand Down Expand Up @@ -284,7 +293,7 @@ def test_redirect_server():
thread.start()

# send a request, verify the server exposes the query
url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec
url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec
response = urllib.request.urlopen(url) # nosec

assert response.code == 200
Expand Down
45 changes: 32 additions & 13 deletions sdk/identity/azure-identity/tests/test_device_code_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def test_disable_automatic_authentication():

empty_cache = TokenCache() # empty cache makes silent auth impossible
transport = Mock(send=Mock(side_effect=Exception("no request should be sent")))
credential = DeviceCodeCredential("client-id", disable_automatic_authentication=True, transport=transport, _cache=empty_cache)
credential = DeviceCodeCredential(
"client-id", disable_automatic_authentication=True, transport=transport, _cache=empty_cache
)

with pytest.raises(AuthenticationRequiredError):
credential.get_token("scope")
Expand All @@ -102,6 +104,7 @@ def test_disable_automatic_authentication():
def test_policies_configurable():
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock())

client_id = "client-id"
transport = validating_transport(
requests=[Request()] * 3,
responses=[
Expand All @@ -115,12 +118,16 @@ def test_policies_configurable():
"expires_in": 42,
}
),
mock_response(json_payload=dict(build_aad_response(access_token="**"), scope="scope")),
mock_response(
json_payload=dict(
build_aad_response(access_token="**", id_token=build_id_token(aud=client_id)), scope="scope"
)
),
],
)

credential = DeviceCodeCredential(
client_id="client-id", prompt_callback=Mock(), policies=[policy], transport=transport, _cache=TokenCache()
client_id=client_id, prompt_callback=Mock(), policies=[policy], transport=transport, _cache=TokenCache()
)

credential.get_token("scope")
Expand All @@ -129,6 +136,7 @@ def test_policies_configurable():


def test_user_agent():
client_id = "client-id"
transport = validating_transport(
requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})],
responses=[
Expand All @@ -141,18 +149,23 @@ def test_user_agent():
"expires_in": 42,
}
),
mock_response(json_payload=dict(build_aad_response(access_token="**"), scope="scope")),
mock_response(
json_payload=dict(
build_aad_response(access_token="**", id_token=build_id_token(aud=client_id)), scope="scope"
)
),
],
)

credential = DeviceCodeCredential(
client_id="client-id", prompt_callback=Mock(), transport=transport, _cache=TokenCache()
client_id=client_id, prompt_callback=Mock(), transport=transport, _cache=TokenCache()
)

credential.get_token("scope")


def test_device_code_credential():
client_id = "client-id"
expected_token = "access-token"
user_code = "user-code"
verification_uri = "verification-uri"
Expand All @@ -172,20 +185,26 @@ def test_device_code_credential():
}
),
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": expires_in,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "_",
}
json_payload=dict(
build_aad_response(
access_token=expected_token,
expires_in=expires_in,
refresh_token="_",
id_token=build_id_token(aud=client_id),
),
scope="scope",
),
),
],
)

callback = Mock()
credential = DeviceCodeCredential(
client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False, _cache=TokenCache()
client_id=client_id,
prompt_callback=callback,
transport=transport,
instance_discovery=False,
_cache=TokenCache(),
)

now = datetime.datetime.utcnow()
Expand Down
48 changes: 34 additions & 14 deletions sdk/identity/azure-identity/tests/test_interactive_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,21 @@
except ImportError: # python < 3.3
from mock import Mock, patch # type: ignore

from helpers import build_aad_response
from helpers import build_aad_response, build_id_token, id_token_claims


# fake object for tests which need to exercise request_token but don't care about its return value
REQUEST_TOKEN_RESULT = build_aad_response(
access_token="***",
id_token_claims=id_token_claims(
aud="...",
iss="http://localhost/tenant",
sub="subject",
preferred_username="...",
tenant_id="...",
object_id="...",
),
)


class MockCredential(InteractiveCredential):
Expand Down Expand Up @@ -132,7 +146,7 @@ def test_scopes_round_trip():

def validate_scopes(*scopes, **_):
assert scopes == (scope,)
return {"access_token": "**", "expires_in": 42}
return REQUEST_TOKEN_RESULT

request_token = Mock(wraps=validate_scopes)
credential = MockCredential(disable_automatic_authentication=True, request_token=request_token)
Expand All @@ -158,7 +172,7 @@ def test_authenticate_default_scopes(authority, expected_scope):

def validate_scopes(*scopes):
assert scopes == (expected_scope,)
return {"access_token": "**", "expires_in": 42}
return REQUEST_TOKEN_RESULT

request_token = Mock(wraps=validate_scopes)
MockCredential(authority=authority, request_token=request_token).authenticate()
Expand All @@ -176,7 +190,7 @@ def test_authenticate_unknown_cloud():
def test_authenticate_ignores_disable_automatic_authentication(option):
"""authenticate should prompt for authentication regardless of the credential's configuration"""

request_token = Mock(return_value={"access_token": "**", "expires_in": 42})
request_token = Mock(return_value=REQUEST_TOKEN_RESULT)
MockCredential(request_token=request_token, disable_automatic_authentication=option).authenticate()
assert request_token.call_count == 1, "credential didn't begin interactive authentication"

Expand Down Expand Up @@ -296,19 +310,22 @@ def _request_token(self, *_, **__):
assert record.home_account_id == "{}.{}".format(object_id, home_tenant)


def test_home_account_id_no_client_info():
"""the credential should use the subject claim as home_account_id when MSAL doesn't provide client_info"""
def test_adfs():
"""the credential should be able to construct an AuthenticationRecord from an ADFS response returned by MSAL"""

authority = "localhost"
subject = "subject"
tenant = "adfs"
username = "username"
msal_response = build_aad_response(access_token="***", refresh_token="**")
msal_response["id_token_claims"] = {
"aud": "client-id",
"iss": "https://localhost",
"object_id": "some-guid",
"tid": "some-tenant",
"preferred_username": "me",
"sub": subject,
}
msal_response["id_token_claims"] = id_token_claims(
aud="client-id",
iss="https://{}/{}".format(authority, tenant),
sub=subject,
tenant_id=tenant,
object_id="object-id",
upn=username,
)

class TestCredential(InteractiveCredential):
def __init__(self, **kwargs):
Expand All @@ -318,4 +335,7 @@ def _request_token(self, *_, **__):
return msal_response

record = TestCredential().authenticate()
assert record.authority == authority
assert record.home_account_id == subject
assert record.tenant_id == tenant
assert record.username == username
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def get_account_event(
uid=uid,
utid=utid,
refresh_token=refresh_token,
id_token=build_id_token(aud=client_id, preferred_username=username),
id_token=build_id_token(aud=client_id, username=username),
foci="1",
**kwargs
),
Expand Down
Loading

0 comments on commit b12d9fd

Please sign in to comment.