From d6c7c019b12a90fb9916f4b5ecf0ab6eb7616a4d Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 16 Mar 2020 21:36:03 +0800 Subject: [PATCH 1/6] Add EnvVar AZURE_AUTHORITY_HOST --- .../azure-identity/azure/identity/_authn_client.py | 7 +++++-- sdk/identity/azure-identity/azure/identity/_constants.py | 1 + .../azure-identity/azure/identity/_credentials/default.py | 3 ++- .../azure/identity/_internal/aad_client_base.py | 6 ++++-- .../azure/identity/_internal/msal_credentials.py | 6 ++++-- .../azure/identity/_internal/shared_token_cache.py | 8 +++++--- .../azure/identity/aio/_credentials/default.py | 3 ++- 7 files changed, 23 insertions(+), 11 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index 7b927c9d621a..bc55e0e66c8b 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -5,6 +5,7 @@ import abc import calendar import time +import os from msal import TokenCache @@ -22,7 +23,7 @@ UserAgentPolicy, ) from azure.core.pipeline.transport import RequestsTransport, HttpRequest -from ._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities +from ._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities, EnvironmentVariables from ._internal.user_agent import USER_AGENT try: @@ -61,7 +62,9 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl else: if not tenant: raise ValueError("'tenant' is required") - authority = authority or KnownAuthorities.AZURE_PUBLIC_CLOUD + if not authority: + authority = os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token")) self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index ffb0ed644b58..373f4b2844b4 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -29,6 +29,7 @@ class EnvironmentVariables: MSI_ENDPOINT = "MSI_ENDPOINT" MSI_SECRET = "MSI_SECRET" + AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST" class Endpoints: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 137261bb679b..6516c13ed3ac 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -61,7 +61,8 @@ class DefaultAzureCredential(ChainedTokenCredential): """ def __init__(self, **kwargs): - authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD + authority = kwargs.pop("authority", None) or os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) shared_cache_tenant_id = kwargs.pop( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index a7e30ba90ab4..1da08b1bb268 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -6,6 +6,7 @@ import copy import functools import time +import os try: from typing import TYPE_CHECKING @@ -17,7 +18,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from .._constants import KnownAuthorities +from .._constants import KnownAuthorities, EnvironmentVariables try: ABC = abc.ABC @@ -34,7 +35,8 @@ class AadClientBase(ABC): def __init__(self, tenant_id, client_id, cache=None, **kwargs): # type: (str, str, Optional[TokenCache], **Any) -> None - authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = kwargs.pop("authority", None) or os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) if authority[-1] == "/": authority = authority[:-1] token_endpoint = "https://" + "/".join((authority, tenant_id, "oauth2/v2.0/token")) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index ea49aa6aa3ba..df066032a539 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -7,6 +7,7 @@ """ import abc import time +import os import msal from azure.core.credentials import AccessToken @@ -14,7 +15,7 @@ from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter -from .._constants import KnownAuthorities +from .._constants import KnownAuthorities, EnvironmentVariables try: ABC = abc.ABC @@ -37,7 +38,8 @@ class MsalCredential(ABC): def __init__(self, client_id, client_credential=None, **kwargs): # type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None tenant_id = kwargs.pop("tenant_id", "organizations") - authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = kwargs.pop("authority", None) or os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) self._base_url = "https://" + "/".join((authority.strip("/"), tenant_id.strip("/"))) self._client_credential = client_credential self._client_id = client_id diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 1091f5e4525a..24440cd93068 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -8,8 +8,9 @@ from msal import TokenCache -from .. import CredentialUnavailableError -from .._constants import KnownAuthorities +from azure.core.exceptions import ClientAuthenticationError +from .._constants import KnownAuthorities, EnvironmentVariables + try: ABC = abc.ABC @@ -86,7 +87,8 @@ class SharedTokenCacheBase(ABC): def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument # type: (Optional[str], **Any) -> None - self._authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD + self._authority = kwargs.pop("authority", None) or os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) self._authority_aliases = KNOWN_ALIASES.get(self._authority) or frozenset((self._authority,)) self._username = username self._tenant_id = kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 9be822fb7409..05aef971e3a8 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -52,7 +52,8 @@ class DefaultAzureCredential(ChainedTokenCredential): """ def __init__(self, **kwargs): - authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD + authority = kwargs.pop("authority", None) or os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) shared_cache_tenant_id = kwargs.pop( From 6f93277ac98093e2872933ea24474a1b4172be9a Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 3 Apr 2020 10:09:10 +0800 Subject: [PATCH 2/6] Add test of EnvVar AZURE_AUTHORITY_HOST --- .../azure/identity/_internal/shared_token_cache.py | 3 +-- sdk/identity/azure-identity/tests/test_default.py | 5 +++++ sdk/identity/azure-identity/tests/test_default_async.py | 5 +++++ .../azure-identity/tests/test_shared_cache_credential.py | 5 ++++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 24440cd93068..49fd47ecb336 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -8,10 +8,9 @@ from msal import TokenCache -from azure.core.exceptions import ClientAuthenticationError +from .. import CredentialUnavailableError from .._constants import KnownAuthorities, EnvironmentVariables - try: ABC = abc.ABC except AttributeError: # Python 2.7, abc exists, but not ABC diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index b3ad7542b996..fe24cc965e15 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -75,6 +75,9 @@ def send(request, **_): EnvironmentVariables.AZURE_CLIENT_SECRET: "secret", EnvironmentVariables.AZURE_TENANT_ID: "tenant_id", } + if os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST): + environment.update({EnvironmentVariables.AZURE_AUTHORITY_HOST: os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST)}) with patch("os.environ", environment): transport = Mock(send=send) access_token, _ = DefaultAzureCredential(authority=authority_kwarg, transport=transport).get_token("scope") @@ -99,6 +102,8 @@ def send(request, **_): # all credentials not representing managed identities should use a specified authority or default to public cloud exercise_credentials("authority.com") exercise_credentials(None, KnownAuthorities.AZURE_PUBLIC_CLOUD) + with patch('os.environ', {EnvironmentVariables.AZURE_AUTHORITY_HOST: "localhost.com"}): + exercise_credentials(None, os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST)) def test_exclude_options(): diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 27fa7744db72..8b3cf86ed908 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -72,6 +72,9 @@ async def send(request, **_): EnvironmentVariables.AZURE_CLIENT_SECRET: "secret", EnvironmentVariables.AZURE_TENANT_ID: "tenant_id", } + if os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST): + environment.update({EnvironmentVariables.AZURE_AUTHORITY_HOST: os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST)}) with patch("os.environ", environment): transport = Mock(send=send) if authority_kwarg: @@ -104,6 +107,8 @@ async def send(request, **_): # all credentials not representing managed identities should use a specified authority or default to public cloud await exercise_credentials("authority.com") await exercise_credentials(None, KnownAuthorities.AZURE_PUBLIC_CLOUD) + with patch('os.environ', {EnvironmentVariables.AZURE_AUTHORITY_HOST: "localhost.com"}): + await exercise_credentials(None, os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST)) def test_exclude_options(): diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index a6b4fea00ee1..65bf0bed2899 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -5,6 +5,7 @@ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( KNOWN_ALIASES, MULTIPLE_ACCOUNTS, @@ -15,6 +16,7 @@ from azure.identity._internal.user_agent import USER_AGENT from msal import TokenCache import pytest +import os try: from unittest.mock import Mock @@ -473,7 +475,8 @@ def get_account_event( foci="1", ), "client_id": client_id, - "token_endpoint": "https://" + "/".join((authority or KnownAuthorities.AZURE_PUBLIC_CLOUD, utid, "/path")), + "token_endpoint": "https://" + "/".join((authority or os.environ.get( + EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD), utid, "/path")), "scope": scopes or ["scope"], } From 23842f81daeb92cedf1e8502b9d207f7db6b7237 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 6 Apr 2020 09:56:37 -0700 Subject: [PATCH 3/6] simplify DefaultAzureCredential tests --- .../azure-identity/tests/test_default.py | 102 +++++++--------- .../tests/test_default_async.py | 109 +++++++----------- 2 files changed, 85 insertions(+), 126 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index fe24cc965e15..842611afc7e6 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -45,65 +45,49 @@ def test_iterates_only_once(): assert successful_credential.get_token.call_count == n + 1 -def test_default_credential_authority(): - expected_access_token = "***" - response = mock_response( - json_payload={ - "access_token": expected_access_token, - "expires_in": 0, - "expires_on": 42, - "not_before": 0, - "resource": "scope", - "token_type": "Bearer", - } - ) - - def exercise_credentials(authority_kwarg, expected_authority=None): - expected_authority = expected_authority or authority_kwarg - - def send(request, **_): - url = urlparse(request.url) - assert url.scheme == "https", "Unexpected scheme '{}'".format(url.scheme) - assert url.netloc == expected_authority, "Expected authority '{}', actual was '{}'".format( - expected_authority, url.netloc - ) - return response - - # environment credential configured with client secret should respect authority - environment = { - EnvironmentVariables.AZURE_CLIENT_ID: "client_id", - EnvironmentVariables.AZURE_CLIENT_SECRET: "secret", - EnvironmentVariables.AZURE_TENANT_ID: "tenant_id", - } - if os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST): - environment.update({EnvironmentVariables.AZURE_AUTHORITY_HOST: os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST)}) - with patch("os.environ", environment): - transport = Mock(send=send) - access_token, _ = DefaultAzureCredential(authority=authority_kwarg, transport=transport).get_token("scope") - assert access_token == expected_access_token - - # managed identity credential should ignore authority - with patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://some.url"}): - transport = Mock(send=lambda *_, **__: response) - access_token, _ = DefaultAzureCredential(authority=authority_kwarg, transport=transport).get_token("scope") - assert access_token == expected_access_token - - # shared cache credential should respect authority - upn = os.environ.get(EnvironmentVariables.AZURE_USERNAME, "spam@eggs") # preferring environment values to - tenant = os.environ.get(EnvironmentVariables.AZURE_TENANT_ID, "tenant") # prevent failure during live runs - account = get_account_event(username=upn, uid="guid", utid=tenant, authority=authority_kwarg) - cache = populated_cache(account) - with patch.object(SharedTokenCacheCredential, "supported"): - credential = DefaultAzureCredential(_cache=cache, authority=authority_kwarg, transport=Mock(send=send)) - access_token, _ = credential.get_token("scope") - assert access_token == expected_access_token - - # all credentials not representing managed identities should use a specified authority or default to public cloud - exercise_credentials("authority.com") - exercise_credentials(None, KnownAuthorities.AZURE_PUBLIC_CLOUD) - with patch('os.environ', {EnvironmentVariables.AZURE_AUTHORITY_HOST: "localhost.com"}): - exercise_credentials(None, os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST)) +def test__authority(): + """the credential should accept authority configuration by keyword argument or environment""" + + def test_initialization(mock_credential, expect_argument): + authority = "localhost" + + DefaultAzureCredential(authority=authority) + assert mock_credential.call_count == 1 + + # N.B. if os.environ has been patched somewhere in the stack, that patch is in place here + environment = dict(os.environ, **{EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}) + with patch.dict(DefaultAzureCredential.__module__ + ".os.environ", environment, clear=True): + DefaultAzureCredential() + assert mock_credential.call_count == 2 + + for _, kwargs in mock_credential.call_args_list: + if expect_argument: + assert kwargs["authority"] == authority + else: + assert "authority" not in kwargs + + # authority should be passed to EnvironmentCredential as a keyword argument + environment = {var: "foo" for var in EnvironmentVariables.CLIENT_SECRET_VARS} + with patch(DefaultAzureCredential.__module__ + ".EnvironmentCredential") as mock_credential: + with patch.dict("os.environ", environment, clear=True): + test_initialization(mock_credential, expect_argument=True) + + # authority should be passed to SharedTokenCacheCredential as a keyword argument + with patch(DefaultAzureCredential.__module__ + ".SharedTokenCacheCredential") as mock_credential: + mock_credential.supported = lambda: True + test_initialization(mock_credential, expect_argument=True) + + # authority should not be passed to ManagedIdentityCredential + with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential: + with patch.dict("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "localhost"}, clear=True): + test_initialization(mock_credential, expect_argument=False) + + # authority should not be passed to AzureCliCredential + with patch(DefaultAzureCredential.__module__ + ".AzureCliCredential") as mock_credential: + with patch(DefaultAzureCredential.__module__ + ".SharedTokenCacheCredential") as shared_cache: + shared_cache.supported = lambda: False + with patch.dict("os.environ", {}, clear=True): + test_initialization(mock_credential, expect_argument=False) def test_exclude_options(): diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 8b3cf86ed908..0472d351d8fd 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -40,75 +40,50 @@ async def test_iterates_only_once(): assert successful_credential.get_token.call_count == n + 1 -@pytest.mark.asyncio -async def test_default_credential_authority(): - authority = "authority.com" - expected_access_token = "***" - response = mock_response( - json_payload={ - "access_token": expected_access_token, - "expires_in": 0, - "expires_on": 42, - "not_before": 0, - "resource": "scope", - "token_type": "Bearer", - } - ) +def test_authority(): + """the credential should accept authority configuration by keyword argument or environment""" - async def exercise_credentials(authority_kwarg, expected_authority=None): - expected_authority = expected_authority or authority_kwarg - - async def send(request, **_): - url = urlparse(request.url) - assert url.scheme == "https", "Unexpected scheme '{}'".format(url.scheme) - assert url.netloc == expected_authority, "Expected authority '{}', actual was '{}'".format( - expected_authority, url.netloc - ) - return response - - # environment credential configured with client secret should respect authority - environment = { - EnvironmentVariables.AZURE_CLIENT_ID: "client_id", - EnvironmentVariables.AZURE_CLIENT_SECRET: "secret", - EnvironmentVariables.AZURE_TENANT_ID: "tenant_id", - } - if os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST): - environment.update({EnvironmentVariables.AZURE_AUTHORITY_HOST: os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST)}) - with patch("os.environ", environment): - transport = Mock(send=send) - if authority_kwarg: - credential = DefaultAzureCredential(authority=authority_kwarg, transport=transport) - else: - credential = DefaultAzureCredential(transport=transport) - access_token, _ = await credential.get_token("scope") - assert access_token == expected_access_token - - # managed identity credential should ignore authority - with patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "https://some.url"}): - transport = Mock(send=wrap_in_future(lambda *_, **__: response)) - if authority_kwarg: - credential = DefaultAzureCredential(authority=authority_kwarg, transport=transport) + def test_initialization(mock_credential, expect_argument): + authority = "localhost" + + DefaultAzureCredential(authority=authority) + assert mock_credential.call_count == 1 + + # N.B. if os.environ has been patched somewhere in the stack, that patch is in place here + environment = dict(os.environ, **{EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}) + with patch.dict(DefaultAzureCredential.__module__ + ".os.environ", environment, clear=True): + DefaultAzureCredential() + assert mock_credential.call_count == 2 + + for _, kwargs in mock_credential.call_args_list: + if expect_argument: + assert kwargs["authority"] == authority else: - credential = DefaultAzureCredential(transport=transport) - access_token, _ = await credential.get_token("scope") - assert access_token == expected_access_token - - # shared cache credential should respect authority - upn = os.environ.get(EnvironmentVariables.AZURE_USERNAME, "spam@eggs") # preferring environment values to - tenant = os.environ.get(EnvironmentVariables.AZURE_TENANT_ID, "tenant") # prevent failure during live runs - account = get_account_event(username=upn, uid="guid", utid=tenant, authority=authority_kwarg) - cache = populated_cache(account) - with patch.object(SharedTokenCacheCredential, "supported"): - credential = DefaultAzureCredential(_cache=cache, authority=authority_kwarg, transport=Mock(send=send)) - access_token, _ = await credential.get_token("scope") - assert access_token == expected_access_token - - # all credentials not representing managed identities should use a specified authority or default to public cloud - await exercise_credentials("authority.com") - await exercise_credentials(None, KnownAuthorities.AZURE_PUBLIC_CLOUD) - with patch('os.environ', {EnvironmentVariables.AZURE_AUTHORITY_HOST: "localhost.com"}): - await exercise_credentials(None, os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST)) + assert "authority" not in kwargs + + # authority should be passed to EnvironmentCredential as a keyword argument + environment = {var: "foo" for var in EnvironmentVariables.CLIENT_SECRET_VARS} + with patch(DefaultAzureCredential.__module__ + ".EnvironmentCredential") as mock_credential: + with patch.dict("os.environ", environment, clear=True): + test_initialization(mock_credential, expect_argument=True) + + # authority should be passed to SharedTokenCacheCredential as a keyword argument + with patch(DefaultAzureCredential.__module__ + ".SharedTokenCacheCredential") as mock_credential: + mock_credential.supported = lambda: True + with patch.dict("os.environ", {}, clear=True): + test_initialization(mock_credential, expect_argument=True) + + # authority should not be passed to ManagedIdentityCredential + with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential: + with patch.dict("os.environ", {EnvironmentVariables.MSI_ENDPOINT: "_"}, clear=True): + test_initialization(mock_credential, expect_argument=False) + + # authority should not be passed to AzureCliCredential + with patch(DefaultAzureCredential.__module__ + ".AzureCliCredential") as mock_credential: + with patch(DefaultAzureCredential.__module__ + ".SharedTokenCacheCredential") as shared_cache: + shared_cache.supported = lambda: False + with patch.dict("os.environ", {}, clear=True): + test_initialization(mock_credential, expect_argument=False) def test_exclude_options(): From 64d1604d54daee2b390778d548882edd3d5e9b1c Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 6 Apr 2020 09:15:54 -0700 Subject: [PATCH 4/6] tests --- .../azure-identity/tests/test_aad_client.py | 11 +++++-- .../tests/test_aad_client_async.py | 9 +++++- .../azure-identity/tests/test_authn_client.py | 10 +++++- .../tests/test_authn_client_async.py | 8 ++++- .../tests/test_certificate_credential.py | 13 ++++++-- .../test_certificate_credential_async.py | 11 ++++++- .../tests/test_client_secret_credential.py | 32 +++++++++++++++++-- .../test_client_secret_credential_async.py | 31 +++++++++++++++++- .../tests/test_environment_credential.py | 22 +++++++++++++ .../test_environment_credential_async.py | 25 ++++++++++++++- .../tests/test_shared_cache_credential.py | 27 +++++++++++++--- .../test_shared_cache_credential_async.py | 20 ++++++++++++ 12 files changed, 202 insertions(+), 17 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 7a165b1dc17d..5e3012e06061 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -5,6 +5,7 @@ import functools from azure.core.exceptions import ClientAuthenticationError +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.aad_client import AadClient import pytest from six.moves.urllib_parse import urlparse @@ -12,9 +13,9 @@ from helpers import mock_response try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore class MockClient(AadClient): @@ -113,3 +114,9 @@ def send(request, **_): client.obtain_token_by_authorization_code("code", "uri", "scope") client.obtain_token_by_refresh_token("refresh token", "scope") + + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send)) + client.obtain_token_by_authorization_code("code", "uri", "scope") + client.obtain_token_by_refresh_token("refresh token", "scope") diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index ca1b59bb03a7..a2d24ce8e3d9 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -2,9 +2,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from unittest.mock import Mock +from unittest.mock import Mock, patch from urllib.parse import urlparse +from azure.identity._constants import EnvironmentVariables from azure.identity.aio._internal.aad_client import AadClient import pytest @@ -57,3 +58,9 @@ async def send(request, **_): await client.obtain_token_by_authorization_code("code", "uri", "scope") await client.obtain_token_by_refresh_token("refresh token", "scope") + + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + client = AadClient(tenant_id=tenant, client_id="client id", transport=Mock(send=send)) + await client.obtain_token_by_authorization_code("code", "uri", "scope") + await client.obtain_token_by_refresh_token("refresh token", "scope") diff --git a/sdk/identity/azure-identity/tests/test_authn_client.py b/sdk/identity/azure-identity/tests/test_authn_client.py index 44f58d3e5012..e2772e130d27 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client.py +++ b/sdk/identity/azure-identity/tests/test_authn_client.py @@ -14,6 +14,7 @@ from azure.core.credentials import AccessToken from azure.identity._authn_client import AuthnClient +from azure.identity._constants import EnvironmentVariables from six.moves.urllib_parse import urlparse from helpers import mock_response @@ -205,7 +206,7 @@ def mock_send(request, **kwargs): def test_request_url(): - authority = "authority.com" + authority = "localhost" tenant = "expected_tenant" def validate_url(url): @@ -222,3 +223,10 @@ def mock_send(request, **kwargs): client.request_token(("scope",)) request = client.get_refresh_token_grant_request({"secret": "***"}, "scope") validate_url(request.url) + + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send)) + client.request_token(("scope",)) + request = client.get_refresh_token_grant_request({"secret": "***"}, "scope") + validate_url(request.url) diff --git a/sdk/identity/azure-identity/tests/test_authn_client_async.py b/sdk/identity/azure-identity/tests/test_authn_client_async.py index b2982c2190b5..7966787d9b52 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client_async.py +++ b/sdk/identity/azure-identity/tests/test_authn_client_async.py @@ -3,10 +3,11 @@ # Licensed under the MIT License. # ------------------------------------ import asyncio -from unittest.mock import Mock +from unittest.mock import Mock, patch from urllib.parse import urlparse import pytest +from azure.identity._constants import EnvironmentVariables from azure.identity.aio._authn_client import AsyncAuthnClient from helpers import mock_response @@ -27,3 +28,8 @@ def mock_send(request, **kwargs): client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)), authority=authority) await client.request_token(("scope",)) + + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send))) + await client.request_token(("scope",)) diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 233f457d0434..c861ed73d936 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -7,6 +7,7 @@ from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import CertificateCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -18,9 +19,9 @@ from helpers import build_aad_response, urlsafeb64_decode, mock_response, Request, validating_transport try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem") CERT_WITH_PASSWORD_PATH = os.path.join(os.path.dirname(__file__), "certificate-with-password.pem") @@ -84,6 +85,14 @@ def mock_send(request, **kwargs): token = cred.get_token("scope") assert token.token == access_token + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + credential = CertificateCredential( + tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send) + ) + credential.get_token("scope") + assert token.token == access_token + @pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) def test_request_body(cert_path, cert_password): diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index ed278e79ddd1..1efb467a56be 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -2,10 +2,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from unittest.mock import Mock +from unittest.mock import Mock, patch from urllib.parse import urlparse from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import CertificateCredential @@ -98,6 +99,14 @@ async def mock_send(request, **kwargs): token = await cred.get_token("scope") assert token.token == access_token + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + credential = CertificateCredential( + tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send) + ) + await credential.get_token("scope") + assert token.token == access_token + @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index 6e28f90840bb..e92c8cccfdb8 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -7,15 +7,17 @@ from azure.core.credentials import AccessToken from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import ClientSecretCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import pytest +from six.moves.urllib_parse import urlparse from helpers import build_aad_response, mock_response, Request, validating_transport try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore def test_no_scopes(): @@ -78,6 +80,32 @@ def test_client_secret_credential(): assert token.token == access_token +def test_request_url(): + authority = "localhost" + tenant_id = "expected_tenant" + access_token = "***" + + def mock_send(request, **kwargs): + parsed = urlparse(request.url) + assert parsed.scheme == "https" + assert parsed.netloc == authority + assert parsed.path.startswith("/" + tenant_id) + + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) + + credential = ClientSecretCredential( + tenant_id, "client-id", "secret", transport=Mock(send=mock_send), authority=authority + ) + token = credential.get_token("scope") + assert token.token == access_token + + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + credential = ClientSecretCredential(tenant_id, "client-id", "secret", transport=Mock(send=mock_send)) + credential.get_token("scope") + assert token.token == access_token + + def test_cache(): expired = "this token's expired" now = int(time.time()) diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 188bc039a550..2253910fafb2 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -4,10 +4,12 @@ # ------------------------------------ import asyncio import time -from unittest.mock import Mock +from unittest.mock import Mock, patch +from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import ClientSecretCredential from helpers import build_aad_response, mock_response, Request @@ -104,6 +106,33 @@ async def test_client_secret_credential(): assert token.token == access_token +@pytest.mark.asyncio +async def test_request_url(): + authority = "localhost" + tenant_id = "expected_tenant" + access_token = "***" + + async def mock_send(request, **kwargs): + parsed = urlparse(request.url) + assert parsed.scheme == "https" + assert parsed.netloc == authority + assert parsed.path.startswith("/" + tenant_id) + + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) + + credential = ClientSecretCredential( + tenant_id, "client-id", "secret", transport=Mock(send=mock_send), authority=authority + ) + token = await credential.get_token("scope") + assert token.token == access_token + + # authority can be configured via environment variable + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + credential = ClientSecretCredential(tenant_id, "client-id", "secret", transport=Mock(send=mock_send)) + await credential.get_token("scope") + assert token.token == access_token + + @pytest.mark.asyncio async def test_cache(): expired = "this token's expired" diff --git a/sdk/identity/azure-identity/tests/test_environment_credential.py b/sdk/identity/azure-identity/tests/test_environment_credential.py index 03c65f8f9b9c..ebe77ed76571 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential.py @@ -31,3 +31,25 @@ def test_incomplete_configuration(): with mock.patch.dict(os.environ, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: EnvironmentCredential().get_token("scope") + + +@pytest.mark.parametrize( + "credential_name,environment_variables", + ( + ("ClientSecretCredential", EnvironmentVariables.CLIENT_SECRET_VARS), + ("CertificateCredential", EnvironmentVariables.CERT_VARS), + ("UsernamePasswordCredential", EnvironmentVariables.USERNAME_PASSWORD_VARS), + ), +) +def test_passes_authority_argument(credential_name, environment_variables): + """the credential pass the 'authority' keyword argument to its inner credential""" + + authority = "authority" + + with mock.patch.dict("os.environ", {variable: "foo" for variable in environment_variables}, clear=True): + with mock.patch(EnvironmentCredential.__module__ + "." + credential_name) as mock_credential: + EnvironmentCredential(authority=authority) + + assert mock_credential.call_count == 1 + _, kwargs = mock_credential.call_args + assert kwargs["authority"] == authority diff --git a/sdk/identity/azure-identity/tests/test_environment_credential_async.py b/sdk/identity/azure-identity/tests/test_environment_credential_async.py index 6d8d733ecdaa..7c95172d9752 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential_async.py @@ -5,7 +5,9 @@ import itertools import os -from azure.identity import CredentialUnavailableError, EnvironmentCredential +from azure.identity import CredentialUnavailableError +from azure.identity.aio import EnvironmentCredential +from azure.identity._constants import EnvironmentVariables import pytest from helpers import mock @@ -24,3 +26,24 @@ async def test_incomplete_configuration(): with mock.patch.dict(os.environ, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: await EnvironmentCredential().get_token("scope") + + +@pytest.mark.parametrize( + "credential_name,environment_variables", + ( + ("ClientSecretCredential", EnvironmentVariables.CLIENT_SECRET_VARS), + ("CertificateCredential", EnvironmentVariables.CERT_VARS), + ), +) +def test_passes_authority_argument(credential_name, environment_variables): + """the credential pass the 'authority' keyword argument to its inner credential""" + + authority = "authority" + + with mock.patch.dict("os.environ", {variable: "foo" for variable in environment_variables}, clear=True): + with mock.patch(EnvironmentCredential.__module__ + "." + credential_name) as mock_credential: + EnvironmentCredential(authority=authority) + + assert mock_credential.call_count == 1 + _, kwargs = mock_credential.call_args + assert kwargs["authority"] == authority diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 65bf0bed2899..e6f7572be250 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -13,15 +13,15 @@ NO_ACCOUNTS, NO_MATCHING_ACCOUNTS, ) +from azure.identity._internal import get_default_authority from azure.identity._internal.user_agent import USER_AGENT from msal import TokenCache import pytest -import os try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport @@ -463,6 +463,24 @@ def test_authority_with_no_known_alias(): assert token.token == expected_access_token +def test_authority_environment_variable(): + """the credential should accept an authority by environment variable when none is otherwise specified""" + + authority = "localhost" + expected_access_token = "access-token" + expected_refresh_token = "refresh-token" + account = get_account_event("spam@eggs", "uid", "tenant", authority=authority, refresh_token=expected_refresh_token) + cache = populated_cache(account) + transport = validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + credential = SharedTokenCacheCredential(transport=transport, _cache=cache) + token = credential.get_token("scope") + assert token.token == expected_access_token + + def get_account_event( username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None ): @@ -475,8 +493,7 @@ def get_account_event( foci="1", ), "client_id": client_id, - "token_endpoint": "https://" + "/".join((authority or os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD), utid, "/path")), + "token_endpoint": "https://" + "/".join((authority or get_default_authority(), utid, "/path",)), "scope": scopes or ["scope"], } diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index f200895ad254..9ee971517c98 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -8,6 +8,7 @@ from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import CredentialUnavailableError, KnownAuthorities from azure.identity.aio import SharedTokenCacheCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( KNOWN_ALIASES, MULTIPLE_ACCOUNTS, @@ -525,3 +526,22 @@ async def test_authority_with_no_known_alias(): credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) token = await credential.get_token("scope") assert token.token == expected_access_token + + +@pytest.mark.asyncio +async def test_authority_environment_variable(): + """the credential should accept an authority by environment variable when none is otherwise specified""" + + authority = "localhost" + expected_access_token = "access-token" + expected_refresh_token = "refresh-token" + account = get_account_event("spam@eggs", "uid", "tenant", authority=authority, refresh_token=expected_refresh_token) + cache = populated_cache(account) + transport = async_validating_transport( + requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], + responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], + ) + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + credential = SharedTokenCacheCredential(transport=transport, _cache=cache) + token = await credential.get_token("scope") + assert token.token == expected_access_token From 60a9f4df063dd2cfbfc49b009731ae86c655de75 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 6 Apr 2020 12:19:01 -0700 Subject: [PATCH 5/6] centralize default authority --- .../azure-identity/azure/identity/_authn_client.py | 8 +++----- .../azure/identity/_credentials/default.py | 6 +++--- .../azure/identity/_internal/__init__.py | 10 ++++++++++ .../azure/identity/_internal/aad_client_base.py | 6 ++---- .../azure/identity/_internal/msal_credentials.py | 6 ++---- .../azure/identity/_internal/shared_token_cache.py | 6 +++--- .../azure/identity/aio/_credentials/default.py | 6 +++--- 7 files changed, 26 insertions(+), 22 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index bc55e0e66c8b..a45eb49a81a2 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -5,7 +5,6 @@ import abc import calendar import time -import os from msal import TokenCache @@ -23,7 +22,8 @@ UserAgentPolicy, ) from azure.core.pipeline.transport import RequestsTransport, HttpRequest -from ._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities, EnvironmentVariables +from ._constants import AZURE_CLI_CLIENT_ID +from ._internal import get_default_authority from ._internal.user_agent import USER_AGENT try: @@ -62,9 +62,7 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl else: if not tenant: raise ValueError("'tenant' is required") - if not authority: - authority = os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = authority or get_default_authority() self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token")) self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 6516c13ed3ac..de1103e0a4b9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -5,7 +5,8 @@ import logging import os -from .._constants import EnvironmentVariables, KnownAuthorities +from .._constants import EnvironmentVariables +from .._internal import get_default_authority from .browser import InteractiveBrowserCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential @@ -61,8 +62,7 @@ class DefaultAzureCredential(ChainedTokenCredential): """ def __init__(self, **kwargs): - authority = kwargs.pop("authority", None) or os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = kwargs.pop("authority", None) or get_default_authority() shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) shared_cache_tenant_id = kwargs.pop( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index 6d1e6292445c..a3cbd73d1586 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -2,6 +2,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import os + +from .._constants import EnvironmentVariables, KnownAuthorities + + +def get_default_authority(): + return os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + + +# pylint:disable=wrong-import-position from .aad_client import AadClient from .aad_client_base import AadClientBase from .auth_code_redirect_handler import AuthCodeRedirectServer diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 1da08b1bb268..2fb8ef43f0e0 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -6,7 +6,6 @@ import copy import functools import time -import os try: from typing import TYPE_CHECKING @@ -18,7 +17,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from .._constants import KnownAuthorities, EnvironmentVariables +from . import get_default_authority try: ABC = abc.ABC @@ -35,8 +34,7 @@ class AadClientBase(ABC): def __init__(self, tenant_id, client_id, cache=None, **kwargs): # type: (str, str, Optional[TokenCache], **Any) -> None - authority = kwargs.pop("authority", None) or os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = kwargs.pop("authority", None) or get_default_authority() if authority[-1] == "/": authority = authority[:-1] token_endpoint = "https://" + "/".join((authority, tenant_id, "oauth2/v2.0/token")) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index df066032a539..78552ec106d7 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -7,7 +7,6 @@ """ import abc import time -import os import msal from azure.core.credentials import AccessToken @@ -15,7 +14,7 @@ from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter -from .._constants import KnownAuthorities, EnvironmentVariables +from .._internal import get_default_authority try: ABC = abc.ABC @@ -38,8 +37,7 @@ class MsalCredential(ABC): def __init__(self, client_id, client_credential=None, **kwargs): # type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None tenant_id = kwargs.pop("tenant_id", "organizations") - authority = kwargs.pop("authority", None) or os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = kwargs.pop("authority", None) or get_default_authority() self._base_url = "https://" + "/".join((authority.strip("/"), tenant_id.strip("/"))) self._client_credential = client_credential self._client_id = client_id diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 49fd47ecb336..d874ba5701f2 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -9,7 +9,8 @@ from msal import TokenCache from .. import CredentialUnavailableError -from .._constants import KnownAuthorities, EnvironmentVariables +from .._constants import KnownAuthorities +from .._internal import get_default_authority try: ABC = abc.ABC @@ -86,8 +87,7 @@ class SharedTokenCacheBase(ABC): def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument # type: (Optional[str], **Any) -> None - self._authority = kwargs.pop("authority", None) or os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + self._authority = kwargs.pop("authority", None) or get_default_authority() self._authority_aliases = KNOWN_ALIASES.get(self._authority) or frozenset((self._authority,)) self._username = username self._tenant_id = kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 05aef971e3a8..3caf443bf793 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -6,7 +6,8 @@ import os from typing import TYPE_CHECKING -from ..._constants import EnvironmentVariables, KnownAuthorities +from ..._constants import EnvironmentVariables +from ..._internal import get_default_authority from .azure_cli import AzureCliCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential @@ -52,8 +53,7 @@ class DefaultAzureCredential(ChainedTokenCredential): """ def __init__(self, **kwargs): - authority = kwargs.pop("authority", None) or os.environ.get( - EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + authority = kwargs.pop("authority", None) or get_default_authority() shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) shared_cache_tenant_id = kwargs.pop( From c4ea28507db60e2f77e5e6104648111d674a063c Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 6 Apr 2020 12:34:25 -0700 Subject: [PATCH 6/6] update changelog --- sdk/identity/azure-identity/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 111808b5847b..6a536182eb5a 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -15,6 +15,10 @@ state. ([#10243](https://github.com/Azure/azure-sdk-for-python/issues/10243)) cache is available but contains ambiguous or insufficient information. This causes `ChainedTokenCredential` to correctly try the next credential in the chain. ([#10631](https://github.com/Azure/azure-sdk-for-python/issues/10631)) +- The host of the Active Directory endpoint credentials should use can be set +in the environment variable `AZURE_AUTHORITY_HOST`. See +`azure.identity.KnownAuthorities` for a list of common values. +([#8094](https://github.com/Azure/azure-sdk-for-python/issues/8094)) ## 1.3.1 (2020-03-30)