-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
61 additions
and
203 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,27 +2,24 @@ | |
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# ------------------------------------ | ||
import functools | ||
import random | ||
import socket | ||
import threading | ||
import time | ||
|
||
from azure.core.exceptions import ClientAuthenticationError | ||
from azure.core.pipeline.policies import SansIOHTTPPolicy | ||
from azure.core.pipeline.transport import RequestsTransport | ||
from azure.identity import AuthenticationRequiredError, CredentialUnavailableError, InteractiveBrowserCredential | ||
from azure.identity._internal import AuthCodeRedirectServer | ||
from azure.identity._internal.user_agent import USER_AGENT | ||
from msal import TokenCache | ||
import pytest | ||
from six.moves import urllib, urllib_parse | ||
from six.moves import urllib | ||
|
||
from helpers import ( | ||
build_aad_response, | ||
build_id_token, | ||
get_discovery_response, | ||
mock_response, | ||
msal_validating_transport, | ||
Request, | ||
validating_transport, | ||
) | ||
|
@@ -36,6 +33,40 @@ | |
WEBBROWSER_OPEN = InteractiveBrowserCredential.__module__ + ".webbrowser.open" | ||
|
||
|
||
@pytest.mark.manual | ||
def test_browser_credential(): | ||
"""This test isn't recorded because security features of the implementation prevent replaying sessions""" | ||
|
||
transport = Mock(wraps=RequestsTransport()) | ||
credential = InteractiveBrowserCredential(transport=transport) | ||
scope = "https://vault.azure.net/.default" | ||
record = credential.authenticate(scopes=(scope,)) | ||
|
||
assert record.authority | ||
assert record.home_account_id | ||
assert record.tenant_id | ||
assert record.username | ||
|
||
# credential should have a cached access token for the scope used in authenticate | ||
with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): | ||
token = credential.get_token(scope) | ||
assert token.token | ||
|
||
credential = InteractiveBrowserCredential(transport=transport) | ||
token = credential.get_token(scope) | ||
assert token.token | ||
|
||
with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): | ||
second_token = credential.get_token(scope) | ||
assert second_token.token == token.token | ||
|
||
# every request should have the correct User-Agent | ||
for call in transport.send.call_args_list: | ||
args, _ = call | ||
request = args[0] | ||
assert request.headers["User-Agent"] == USER_AGENT | ||
|
||
|
||
def test_tenant_id_validation(): | ||
"""The credential should raise ValueError when given an invalid tenant_id""" | ||
|
||
|
@@ -56,58 +87,6 @@ def test_no_scopes(): | |
InteractiveBrowserCredential().get_token() | ||
|
||
|
||
def test_authenticate(): | ||
client_id = "client-id" | ||
environment = "localhost" | ||
issuer = "https://" + environment | ||
tenant_id = "some-tenant" | ||
authority = issuer + "/" + tenant_id | ||
|
||
access_token = "***" | ||
scope = "scope" | ||
|
||
# mock AAD response with id token | ||
object_id = "object-id" | ||
home_tenant = "home-tenant-id" | ||
username = "[email protected]" | ||
id_token = build_id_token(aud=client_id, iss=issuer, object_id=object_id, tenant_id=home_tenant, username=username) | ||
auth_response = build_aad_response( | ||
uid=object_id, utid=home_tenant, access_token=access_token, refresh_token="**", id_token=id_token | ||
) | ||
|
||
transport = validating_transport( | ||
requests=[Request(url_substring=issuer)] * 3, | ||
responses=[get_discovery_response(authority)] * 2 + [mock_response(json_payload=auth_response)], | ||
) | ||
|
||
# mock local server fakes successful authentication by immediately returning a well-formed response | ||
oauth_state = "state" | ||
auth_code_response = {"code": "authorization-code", "state": [oauth_state]} | ||
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) | ||
|
||
with patch(InteractiveBrowserCredential.__module__ + ".uuid.uuid4", lambda: oauth_state): | ||
with patch(WEBBROWSER_OPEN, lambda _: True): | ||
credential = InteractiveBrowserCredential( | ||
_cache=TokenCache(), | ||
authority=environment, | ||
client_id=client_id, | ||
_server_class=server_class, | ||
tenant_id=tenant_id, | ||
transport=transport, | ||
) | ||
record = credential.authenticate(scopes=(scope,)) | ||
|
||
assert record.authority == environment | ||
assert record.home_account_id == object_id + "." + home_tenant | ||
assert record.tenant_id == home_tenant | ||
assert record.username == username | ||
|
||
# credential should have a cached access token for the scope used in authenticate | ||
with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): | ||
token = credential.get_token(scope) | ||
assert token.token == access_token | ||
|
||
|
||
def test_disable_automatic_authentication(): | ||
"""When configured for strict silent auth, the credential should raise when silent auth fails""" | ||
|
||
|
@@ -122,142 +101,20 @@ def test_disable_automatic_authentication(): | |
credential.get_token("scope") | ||
|
||
|
||
@patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True) | ||
def test_policies_configurable(): | ||
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) | ||
client_id = "client-id" | ||
transport = validating_transport( | ||
requests=[Request()] * 2, | ||
responses=[ | ||
get_discovery_response(), | ||
mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))), | ||
], | ||
) | ||
# the policy raises an exception so this test can run without authenticating i.e. opening a browser | ||
expected_message = "test_policies_configurable" | ||
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(side_effect=Exception(expected_message))) | ||
|
||
# mock local server fakes successful authentication by immediately returning a well-formed response | ||
oauth_state = "oauth-state" | ||
auth_code_response = {"code": "authorization-code", "state": [oauth_state]} | ||
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) | ||
credential = InteractiveBrowserCredential(policies=[policy]) | ||
|
||
credential = InteractiveBrowserCredential( | ||
policies=[policy], client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache() | ||
) | ||
|
||
with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): | ||
with pytest.raises(ClientAuthenticationError) as ex: | ||
credential.get_token("scope") | ||
|
||
assert expected_message in ex.value.message | ||
assert policy.on_request.called | ||
|
||
|
||
@patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True) | ||
def test_user_agent(): | ||
client_id = "client-id" | ||
transport = validating_transport( | ||
requests=[Request(), Request(required_headers={"User-Agent": USER_AGENT})], | ||
responses=[ | ||
get_discovery_response(), | ||
mock_response(json_payload=build_aad_response(access_token="**", id_token=build_id_token(aud=client_id))), | ||
], | ||
) | ||
|
||
# mock local server fakes successful authentication by immediately returning a well-formed response | ||
oauth_state = "oauth-state" | ||
auth_code_response = {"code": "authorization-code", "state": [oauth_state]} | ||
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) | ||
|
||
credential = InteractiveBrowserCredential( | ||
client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache() | ||
) | ||
|
||
with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): | ||
credential.get_token("scope") | ||
|
||
|
||
@patch("azure.identity._credentials.browser.webbrowser.open") | ||
@pytest.mark.parametrize("redirect_url", ("https://localhost:8042", None)) | ||
def test_interactive_credential(mock_open, redirect_url): | ||
mock_open.side_effect = _validate_auth_request_url | ||
oauth_state = "state" | ||
client_id = "client-id" | ||
expected_refresh_token = "refresh-token" | ||
expected_token = "access-token" | ||
expires_in = 3600 | ||
authority = "authority" | ||
tenant_id = "tenant-id" | ||
endpoint = "https://{}/{}".format(authority, tenant_id) | ||
|
||
transport = msal_validating_transport( | ||
endpoint="https://{}/{}".format(authority, tenant_id), | ||
requests=[Request(url_substring=endpoint)] | ||
+ [ | ||
Request( | ||
authority=authority, url_substring=endpoint, required_data={"refresh_token": expected_refresh_token} | ||
) | ||
], | ||
responses=[ | ||
mock_response( | ||
json_payload=build_aad_response( | ||
access_token=expected_token, | ||
expires_in=expires_in, | ||
refresh_token=expected_refresh_token, | ||
uid="uid", | ||
utid=tenant_id, | ||
id_token=build_id_token(aud=client_id, object_id="uid", tenant_id=tenant_id, iss=endpoint), | ||
token_type="Bearer", | ||
) | ||
), | ||
mock_response( | ||
json_payload=build_aad_response(access_token=expected_token, expires_in=expires_in, token_type="Bearer") | ||
), | ||
], | ||
) | ||
|
||
# mock local server fakes successful authentication by immediately returning a well-formed response | ||
auth_code_response = {"code": "authorization-code", "state": [oauth_state]} | ||
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) | ||
|
||
args = { | ||
"authority": authority, | ||
"tenant_id": tenant_id, | ||
"client_id": client_id, | ||
"transport": transport, | ||
"_cache": TokenCache(), | ||
"_server_class": server_class, | ||
} | ||
if redirect_url: # avoid passing redirect_url=None | ||
args["redirect_uri"] = redirect_url | ||
|
||
credential = InteractiveBrowserCredential(**args) | ||
|
||
# The credential's auth code request includes a uuid which must be included in the redirect. Patching to | ||
# set the uuid requires less code here than a proper mock server. | ||
with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): | ||
token = credential.get_token("scope") | ||
assert token.token == expected_token | ||
assert mock_open.call_count == 1 | ||
assert server_class.call_count == 1 | ||
|
||
if redirect_url: | ||
parsed = urllib_parse.urlparse(redirect_url) | ||
server_class.assert_called_once_with(parsed.hostname, parsed.port, timeout=ANY) | ||
|
||
# token should be cached, get_token shouldn't prompt again | ||
token = credential.get_token("scope") | ||
assert token.token == expected_token | ||
assert mock_open.call_count == 1 | ||
assert server_class.call_count == 1 | ||
|
||
# expired access token -> credential should use refresh token instead of prompting again | ||
now = time.time() | ||
with patch("time.time", lambda: now + expires_in): | ||
token = credential.get_token("scope") | ||
assert token.token == expected_token | ||
assert mock_open.call_count == 1 | ||
|
||
# ensure all expected requests were sent | ||
assert transport.send.call_count == 4 | ||
|
||
|
||
def test_timeout(): | ||
"""get_token should raise ClientAuthenticationError when the server times out without receiving a redirect""" | ||
|
||
|
@@ -312,18 +169,35 @@ def test_redirect_server(): | |
response = urllib.request.urlopen(url) # nosec | ||
|
||
assert response.code == 200 | ||
assert server.query_params[expected_param] == [expected_value] | ||
assert server.query_params[expected_param] == expected_value | ||
|
||
|
||
@patch("azure.identity._credentials.browser.webbrowser.open", lambda _: False) | ||
def test_no_browser(): | ||
transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2) | ||
credential = InteractiveBrowserCredential( | ||
client_id="client-id", _server_class=Mock(), transport=transport, _cache=TokenCache() | ||
) | ||
with pytest.raises(ClientAuthenticationError, match=r".*browser.*"): | ||
with patch(WEBBROWSER_OPEN, lambda _: False): | ||
credential.get_token("scope") | ||
|
||
|
||
def test_redirect_uri(): | ||
"""The credential should configure the redirect server to use a given redirect_uri""" | ||
|
||
expected_hostname = "localhost" | ||
expected_port = 42424 | ||
expected_message = "test_redirect_uri" | ||
server = Mock(side_effect=Exception(expected_message)) # exception prevents this test actually authenticating | ||
credential = InteractiveBrowserCredential( | ||
redirect_uri="htps://{}:{}".format(expected_hostname, expected_port), _server_class=server | ||
) | ||
with pytest.raises(ClientAuthenticationError) as ex: | ||
credential.get_token("scope") | ||
|
||
assert expected_message in ex.value.message | ||
server.assert_called_once_with(expected_hostname, expected_port, timeout=ANY) | ||
|
||
|
||
@pytest.mark.parametrize("redirect_uri", ("http://localhost", "host", "host:42")) | ||
def test_invalid_redirect_uri(redirect_uri): | ||
|
@@ -351,13 +225,3 @@ def test_cannot_bind_redirect_uri(): | |
credential.get_token("scope") | ||
|
||
server.assert_called_once_with("localhost", 42, timeout=ANY) | ||
|
||
|
||
def _validate_auth_request_url(url): | ||
parsed_url = urllib_parse.urlparse(url) | ||
params = urllib_parse.parse_qs(parsed_url.query) | ||
assert params.get("prompt") == ["select_account"], "Auth code request doesn't specify 'prompt=select_account'." | ||
|
||
# when used as a Mock's side_effect, this method's return value is the Mock's return value | ||
# (the real webbrowser.open returns a bool) | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters