diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index a16b8add2874..e163c5befa1b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -72,12 +72,24 @@ class DefaultAzureCredential(ChainedTokenCredential): :keyword str shared_cache_tenant_id: Preferred tenant for :class:`~azure.identity.SharedTokenCacheCredential`. Defaults to the value of environment variable AZURE_TENANT_ID, if any. :keyword str visual_studio_code_tenant_id: Tenant ID to use when authenticating with - :class:`~azure.identity.VisualStudioCodeCredential`. + :class:`~azure.identity.VisualStudioCodeCredential`. Defaults to the "Azure: Tenant" setting in VS Code's user + settings or, when that setting has no value, the "organizations" tenant, which supports only Azure Active + Directory work or school accounts. """ def __init__(self, **kwargs): # type: (**Any) -> None authority = kwargs.pop("authority", None) + + vscode_tenant_id = kwargs.pop( + "visual_studio_code_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID) + ) + vscode_args = {} + if authority: + vscode_args["authority"] = authority + if vscode_tenant_id: + vscode_args["tenant_id"] = vscode_tenant_id + authority = normalize_authority(authority) if authority else get_default_authority() interactive_browser_tenant_id = kwargs.pop( @@ -93,10 +105,6 @@ def __init__(self, **kwargs): "shared_cache_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID) ) - vscode_tenant_id = kwargs.pop( - "visual_studio_code_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID) - ) - exclude_environment_credential = kwargs.pop("exclude_environment_credential", False) exclude_managed_identity_credential = kwargs.pop("exclude_managed_identity_credential", False) exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False) @@ -120,7 +128,7 @@ def __init__(self, **kwargs): except Exception as ex: # pylint:disable=broad-except _LOGGER.info("Shared token cache is unavailable: '%s'", ex) if not exclude_visual_studio_code_credential: - credentials.append(VisualStudioCodeCredential(tenant_id=vscode_tenant_id)) + credentials.append(VisualStudioCodeCredential(**vscode_args)) if not exclude_cli_credential: credentials.append(AzureCliCredential()) if not exclude_powershell_credential: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index 8e07917ef5f0..db54a33bf18f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -2,47 +2,125 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import abc +import os import sys -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING from .._exceptions import CredentialUnavailableError -from .._constants import AZURE_VSCODE_CLIENT_ID -from .._internal import validate_tenant_id +from .._constants import AzureAuthorityHosts, AZURE_VSCODE_CLIENT_ID, EnvironmentVariables +from .._internal import normalize_authority, validate_tenant_id from .._internal.aad_client import AadClient from .._internal.get_token_mixin import GetTokenMixin if sys.platform.startswith("win"): - from .._internal.win_vscode_adapter import get_credentials + from .._internal.win_vscode_adapter import get_refresh_token, get_user_settings elif sys.platform.startswith("darwin"): - from .._internal.macos_vscode_adapter import get_credentials + from .._internal.macos_vscode_adapter import get_refresh_token, get_user_settings else: - from .._internal.linux_vscode_adapter import get_credentials + from .._internal.linux_vscode_adapter import get_refresh_token, get_user_settings if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Optional + from typing import Any, Dict, Optional from azure.core.credentials import AccessToken + from .._internal.aad_client import AadClientBase +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore -class VisualStudioCodeCredential(GetTokenMixin): - """Authenticates as the Azure user signed in to Visual Studio Code. - - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. - :keyword str tenant_id: ID of the tenant the credential should authenticate in. Defaults to the "organizations" - tenant, which supports only Azure Active Directory work or school accounts. - """ +class _VSCodeCredentialBase(ABC): def __init__(self, **kwargs): # type: (**Any) -> None - super(VisualStudioCodeCredential, self).__init__() + super(_VSCodeCredentialBase, self).__init__() + + user_settings = get_user_settings() + self._cloud = user_settings.get("azure.cloud", "AzureCloud") self._refresh_token = None - self._client = kwargs.pop("_client", None) - self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" - validate_tenant_id(self._tenant_id) + self._unavailable_reason = "" + + self._client = kwargs.get("_client") if not self._client: - self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs) + self._initialize(user_settings, **kwargs) + if not (self._client or self._unavailable_reason): + self._unavailable_reason = "Initialization failed" + + @abc.abstractmethod + def _get_client(self, **kwargs): + # type: (**Any) -> AadClientBase + pass + + def _get_refresh_token(self): + # type: () -> str + if not self._refresh_token: + self._refresh_token = get_refresh_token(self._cloud) + if not self._refresh_token: + raise CredentialUnavailableError(message="Failed to get Azure user details from Visual Studio Code.") + return self._refresh_token + + def _initialize(self, vscode_user_settings, **kwargs): + # type: (Dict, **Any) -> None + """Build a client from kwargs merged with VS Code user settings. + + The first stable version of this credential defaulted to Public Cloud and the "organizations" + tenant when it failed to read VS Code user settings. That behavior is preserved here. + """ + + # Precedence for authority: + # 1) VisualStudioCodeCredential(authority=...) + # 2) $AZURE_AUTHORITY_HOST + # 3) authority matching VS Code's "azure.cloud" setting + # 4) default: Public Cloud + authority = kwargs.pop("authority", None) or os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST) + if not authority: + # the application didn't specify an authority, so we figure it out from VS Code settings + if self._cloud == "AzureCloud": + authority = AzureAuthorityHosts.AZURE_PUBLIC_CLOUD + elif self._cloud == "AzureChinaCloud": + authority = AzureAuthorityHosts.AZURE_CHINA + elif self._cloud == "AzureGermanCloud": + authority = AzureAuthorityHosts.AZURE_GERMANY + elif self._cloud == "AzureUSGovernment": + authority = AzureAuthorityHosts.AZURE_GOVERNMENT + else: + # If the value is anything else ("AzureCustomCloud" is the only other known value), + # we need the user to provide the authority because VS Code has no setting for it and + # we can't guess confidently. + self._unavailable_reason = ( + 'VS Code is configured to use a custom cloud. Set keyword argument "authority"' + + ' with the Azure Active Directory endpoint for cloud "{}"'.format(self._cloud) + ) + return + + # Precedence for tenant ID: + # 1) VisualStudioCodeCredential(tenant_id=...) + # 2) "azure.tenant" in VS Code user settings + # 3) default: organizations + tenant_id = kwargs.pop("tenant_id", None) or vscode_user_settings.get("azure.tenant", "organizations") + validate_tenant_id(tenant_id) + if tenant_id.lower() == "adfs": + self._unavailable_reason = "VisualStudioCodeCredential authentication unavailable. ADFS is not supported." + return + + self._client = self._get_client( + authority=normalize_authority(authority), client_id=AZURE_VSCODE_CLIENT_ID, tenant_id=tenant_id, **kwargs + ) + + +class VisualStudioCodeCredential(_VSCodeCredentialBase, GetTokenMixin): + """Authenticates as the Azure user signed in to Visual Studio Code. + + :keyword str authority: authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com". + This argument is required for a custom cloud and usually unnecessary otherwise. Defaults to the authority + matching the "Azure: Cloud" setting in VS Code's user settings or, when that setting has no value, the + authority for Azure Public Cloud. + :keyword str tenant_id: ID of the tenant the credential should authenticate in. Defaults to the "Azure: Tenant" + setting in VS Code's user settings or, when that setting has no value, the "organizations" tenant, which + supports only Azure Active Directory work or school accounts. + """ def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken @@ -55,21 +133,21 @@ def get_token(self, *scopes, **kwargs): :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual Studio Code """ - if self._tenant_id.lower() == "adfs": - raise CredentialUnavailableError( - message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported." - ) + if self._unavailable_reason: + raise CredentialUnavailableError(message=self._unavailable_reason) return super(VisualStudioCodeCredential, self).get_token(*scopes, **kwargs) def _acquire_token_silently(self, *scopes): # type: (*str) -> Optional[AccessToken] + self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes) def _request_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken - if not self._refresh_token: - self._refresh_token = get_credentials() - if not self._refresh_token: - raise CredentialUnavailableError(message="Failed to get Azure user details from Visual Studio Code.") + refresh_token = self._get_refresh_token() + self._client = cast(AadClient, self._client) + return self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) - return self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs) + def _get_client(self, **kwargs): + # type: (**Any) -> AadClient + return AadClient(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/linux_vscode_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/linux_vscode_adapter.py index c981eb8f88eb..d4413fe85f92 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/linux_vscode_adapter.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/linux_vscode_adapter.py @@ -49,22 +49,6 @@ class _SECRET_SCHEMA(ct.Structure): _libsecret = None # type: ignore -def _get_user_settings_path(): - app_data_folder = os.environ["HOME"] - return os.path.join(app_data_folder, ".config", "Code", "User", "settings.json") - - -def _get_user_settings(): - path = _get_user_settings_path() - try: - with open(path) as file: - data = json.load(file) - environment_name = data.get("azure.cloud", "AzureCloud") - return environment_name - except IOError: - return "AzureCloud" - - def _get_refresh_token(service_name, account_name): if not _libsecret: return None @@ -88,18 +72,26 @@ def _get_refresh_token(service_name, account_name): _c_str(account_name), None, ) - if err.value == 0: + if err.value == 0 and p_str: return p_str.decode("utf-8") return None -def get_credentials(): +def get_user_settings(): + try: + path = os.path.join(os.environ["HOME"], ".config", "Code", "User", "settings.json") + with open(path) as file: + return json.load(file) + except Exception as ex: # pylint:disable=broad-except + _LOGGER.debug('Exception reading VS Code user settings: "%s"', ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG)) + return {} + + +def get_refresh_token(cloud_name): try: - environment_name = _get_user_settings() - credentials = _get_refresh_token(VSCODE_CREDENTIALS_SECTION, environment_name) - return credentials - except Exception as ex: # pylint: disable=broad-except + return _get_refresh_token(VSCODE_CREDENTIALS_SECTION, cloud_name) + except Exception as ex: # pylint:disable=broad-except _LOGGER.debug( 'Exception retrieving VS Code credentials: "%s"', ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG) ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/macos_vscode_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/macos_vscode_adapter.py index 6dc78ffefc93..05fd2496837e 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/macos_vscode_adapter.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/macos_vscode_adapter.py @@ -11,36 +11,23 @@ _LOGGER = logging.getLogger(__name__) -def _get_user_settings_path(): - app_data_folder = os.environ["USER"] - return os.path.join(app_data_folder, "Library", "Application Support", "Code", "User", "settings.json") - - -def _get_user_settings(): - path = _get_user_settings_path() +def get_user_settings(): try: + path = os.path.join(os.environ["HOME"], "Library", "Application Support", "Code", "User", "settings.json") with open(path) as file: - data = json.load(file) - environment_name = data.get("azure.cloud", "AzureCloud") - return environment_name - except IOError: - return "AzureCloud" + return json.load(file) + except Exception as ex: # pylint:disable=broad-except + _LOGGER.debug('Exception reading VS Code user settings: "%s"', ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG)) + return {} -def _get_refresh_token(service_name, account_name): - key_chain = Keychain() +def get_refresh_token(cloud_name): try: - return key_chain.get_generic_password(service_name, account_name) + key_chain = Keychain() + return key_chain.get_generic_password(VSCODE_CREDENTIALS_SECTION, cloud_name) except KeychainError: return None - - -def get_credentials(): - try: - environment_name = _get_user_settings() - credentials = _get_refresh_token(VSCODE_CREDENTIALS_SECTION, environment_name) - return credentials - except Exception as ex: # pylint: disable=broad-except + except Exception as ex: # pylint:disable=broad-except _LOGGER.debug( 'Exception retrieving VS Code credentials: "%s"', ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG) ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/win_vscode_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/win_vscode_adapter.py index 9172d5dd836f..341ac263a9dd 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/win_vscode_adapter.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/win_vscode_adapter.py @@ -57,31 +57,19 @@ def _read_credential(service_name, account_name): return None -def _get_user_settings_path(): - app_data_folder = os.environ["APPDATA"] - return os.path.join(app_data_folder, "Code", "User", "settings.json") - - -def _get_user_settings(): - path = _get_user_settings_path() +def get_user_settings(): try: + path = os.path.join(os.environ["APPDATA"], "Code", "User", "settings.json") with open(path) as file: - data = json.load(file) - environment_name = data.get("azure.cloud", "AzureCloud") - return environment_name - except IOError: - return "AzureCloud" - - -def _get_refresh_token(service_name, account_name): - return _read_credential(service_name, account_name) + return json.load(file) + except Exception as ex: # pylint:disable=broad-except + _LOGGER.debug('Exception reading VS Code user settings: "%s"', ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG)) + return {} -def get_credentials(): +def get_refresh_token(cloud_name): try: - environment_name = _get_user_settings() - credentials = _get_refresh_token(VSCODE_CREDENTIALS_SECTION, environment_name) - return credentials + return _read_credential(VSCODE_CREDENTIALS_SECTION, cloud_name) except Exception as ex: # pylint: disable=broad-except _LOGGER.debug( 'Exception retrieving VS Code credentials: "%s"', ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 48134e875d7f..cf5085556b9e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -57,16 +57,28 @@ class DefaultAzureCredential(ChainedTokenCredential): **False**. :keyword str managed_identity_client_id: The client ID of a user-assigned managed identity. Defaults to the value of the environment variable AZURE_CLIENT_ID, if any. If not specified, a system-assigned identity will be used. - :keyword str shared_cache_username: Preferred username for :class:`~azure.identity.SharedTokenCacheCredential`. + :keyword str shared_cache_username: Preferred username for :class:`~azure.identity.aio.SharedTokenCacheCredential`. Defaults to the value of environment variable AZURE_USERNAME, if any. - :keyword str shared_cache_tenant_id: Preferred tenant for :class:`~azure.identity.SharedTokenCacheCredential`. + :keyword str shared_cache_tenant_id: Preferred tenant for :class:`~azure.identity.aio.SharedTokenCacheCredential`. Defaults to the value of environment variable AZURE_TENANT_ID, if any. :keyword str visual_studio_code_tenant_id: Tenant ID to use when authenticating with - :class:`~azure.identity.VisualStudioCodeCredential`. + :class:`~azure.identity.aio.VisualStudioCodeCredential`. Defaults to the "Azure: Tenant" setting in VS Code's + user settings or, when that setting has no value, the "organizations" tenant, which supports only Azure Active + Directory work or school accounts. """ def __init__(self, **kwargs: "Any") -> None: authority = kwargs.pop("authority", None) + + vscode_tenant_id = kwargs.pop( + "visual_studio_code_tenant_id", os.environ.get(EnvironmentVariables.AZURE_TENANT_ID) + ) + vscode_args = {} + if authority: + vscode_args["authority"] = authority + if vscode_tenant_id: + vscode_args["tenant_id"] = vscode_tenant_id + authority = normalize_authority(authority) if authority else get_default_authority() shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) @@ -93,9 +105,7 @@ def __init__(self, **kwargs: "Any") -> None: if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, **kwargs)) if not exclude_managed_identity_credential: - credentials.append( - ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs) - ) + credentials.append(ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs)) if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported(): try: # username and/or tenant_id are only required when the cache contains tokens for multiple identities @@ -106,7 +116,7 @@ def __init__(self, **kwargs: "Any") -> None: except Exception as ex: # pylint:disable=broad-except _LOGGER.info("Shared token cache is unavailable: '%s'", ex) if not exclude_visual_studio_code_credential: - credentials.append(VisualStudioCodeCredential(tenant_id=vscode_tenant_id)) + credentials.append(VisualStudioCodeCredential(**vscode_args)) if not exclude_cli_credential: credentials.append(AzureCliCredential()) if not exclude_powershell_credential: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index 09d817216a89..e180090d40d8 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -2,15 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING from ..._exceptions import CredentialUnavailableError -from ..._constants import AZURE_VSCODE_CLIENT_ID from .._internal import AsyncContextManager from .._internal.aad_client import AadClient from .._internal.get_token_mixin import GetTokenMixin -from ..._credentials.vscode import get_credentials -from ..._internal import validate_tenant_id +from ..._credentials.vscode import _VSCodeCredentialBase if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports @@ -18,31 +16,24 @@ from azure.core.credentials import AccessToken -class VisualStudioCodeCredential(AsyncContextManager, GetTokenMixin): +class VisualStudioCodeCredential(_VSCodeCredentialBase, AsyncContextManager, GetTokenMixin): """Authenticates as the Azure user signed in to Visual Studio Code. - :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', - the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` - defines authorities for other clouds. - :keyword str tenant_id: ID of the tenant the credential should authenticate in. Defaults to the "organizations" - tenant, which supports only Azure Active Directory work or school accounts. + :keyword str authority: authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com". + This argument is required for a custom cloud and usually unnecessary otherwise. Defaults to the authority + matching the "Azure: Cloud" setting in VS Code's user settings or, when that setting has no value, the + authority for Azure Public Cloud. + :keyword str tenant_id: ID of the tenant the credential should authenticate in. Defaults to the "Azure: Tenant" + setting in VS Code's user settings or, when that setting has no value, the "organizations" tenant, which + supports only Azure Active Directory work or school accounts. """ - def __init__(self, **kwargs: "Any") -> None: - super().__init__() - self._refresh_token = None - self._client = kwargs.pop("_client", None) - self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" - validate_tenant_id(self._tenant_id) - if not self._client: - self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs) - - async def __aenter__(self): + async def __aenter__(self) -> "VisualStudioCodeCredential": if self._client: await self._client.__aenter__() return self - async def close(self): + async def close(self) -> None: """Close the credential's transport session.""" if self._client: @@ -58,19 +49,21 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual Studio Code """ - if self._tenant_id.lower() == "adfs": - raise CredentialUnavailableError( - message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported." - ) + if self._unavailable_reason: + raise CredentialUnavailableError(message=self._unavailable_reason) + if not self._client: + raise CredentialUnavailableError("Initialization failed") + return await super().get_token(*scopes, **kwargs) async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]": + self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes) async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": - if not self._refresh_token: - self._refresh_token = get_credentials() - if not self._refresh_token: - raise CredentialUnavailableError(message="Failed to get Azure user details from Visual Studio Code.") + refresh_token = self._get_refresh_token() + self._client = cast(AadClient, self._client) + return await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) - return await self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs) + def _get_client(self, **kwargs: "Any") -> AadClient: + return AadClient(**kwargs) diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index e3d6348de4d5..2805c54f3135 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -99,7 +99,7 @@ def assert_matches(self, request): def add_discrepancy(name, expected, actual): discrepancies.append("{}:\n\t expected: {}\n\t actual: {}".format(name, expected, actual)) - if self.base_url and self.base_url != request.url.split("?")[0]: + if self.base_url and not request.url.startswith(self.base_url): add_discrepancy("base url", self.base_url, request.url) if self.url and self.url != request.url: diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 7fd3459b8dae..f8e3750ce78c 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -229,28 +229,41 @@ def test_shared_cache_username(): assert token.token == expected_access_token -def test_vscode_tenant_id(): - """the credential should allow configuring a tenant ID for VisualStudioCodeCredential by kwarg or environment""" +def test_vscode_arguments(): + credential = DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential" - expected_args = {"tenant_id": "the-tenant"} + # DefaultAzureCredential shouldn't specify a default authority or tenant to VisualStudioCodeCredential + with patch(credential) as mock_credential: + DefaultAzureCredential() + mock_credential.assert_called_once_with() - with patch(DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential") as mock_credential: - DefaultAzureCredential(visual_studio_code_tenant_id=expected_args["tenant_id"]) - mock_credential.assert_called_once_with(**expected_args) + tenant = {"tenant_id": "the-tenant"} + + with patch(credential) as mock_credential: + DefaultAzureCredential(visual_studio_code_tenant_id=tenant["tenant_id"]) + mock_credential.assert_called_once_with(**tenant) # tenant id can also be specified in $AZURE_TENANT_ID - with patch.dict(os.environ, {EnvironmentVariables.AZURE_TENANT_ID: expected_args["tenant_id"]}, clear=True): - with patch(DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential") as mock_credential: + with patch.dict(os.environ, {EnvironmentVariables.AZURE_TENANT_ID: tenant["tenant_id"]}, clear=True): + with patch(credential) as mock_credential: DefaultAzureCredential() - mock_credential.assert_called_once_with(**expected_args) + mock_credential.assert_called_once_with(**tenant) # keyword argument should override environment variable - with patch.dict( - os.environ, {EnvironmentVariables.AZURE_TENANT_ID: "not-" + expected_args["tenant_id"]}, clear=True - ): - with patch(DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential") as mock_credential: - DefaultAzureCredential(visual_studio_code_tenant_id=expected_args["tenant_id"]) - mock_credential.assert_called_once_with(**expected_args) + with patch.dict(os.environ, {EnvironmentVariables.AZURE_TENANT_ID: "not-" + tenant["tenant_id"]}, clear=True): + with patch(credential) as mock_credential: + DefaultAzureCredential(visual_studio_code_tenant_id=tenant["tenant_id"]) + mock_credential.assert_called_once_with(**tenant) + + # DefaultAzureCredential should pass the authority kwarg along + authority = {"authority": "the-authority"} + with patch(credential) as mock_credential: + DefaultAzureCredential(**authority) + mock_credential.assert_called_once_with(**authority) + + with patch(credential) as mock_credential: + DefaultAzureCredential(visual_studio_code_tenant_id=tenant["tenant_id"], **authority) + mock_credential.assert_called_once_with(**dict(authority, **tenant)) @patch(DefaultAzureCredential.__module__ + ".SharedTokenCacheCredential") diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index ae860f3cbc36..11fecf54ce4f 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -2,7 +2,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from azure.identity.aio._credentials import vscode import os +from unittest import mock from unittest.mock import Mock, patch from urllib.parse import urlparse @@ -125,6 +127,7 @@ def assert_credentials_not_present(chain, *credential_classes): credential = DefaultAzureCredential(exclude_powershell_credential=True) assert_credentials_not_present(credential, AzurePowerShellCredential) + @pytest.mark.asyncio async def test_shared_cache_tenant_id(): expected_access_token = "expected-access-token" @@ -211,28 +214,41 @@ async def test_shared_cache_username(): assert token.token == expected_access_token -def test_vscode_tenant_id(): - """the credential should allow configuring a tenant ID for VisualStudioCodeCredential by kwarg or environment""" +def test_vscode_arguments(): + credential = DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential" - expected_args = {"tenant_id": "the-tenant"} + # DefaultAzureCredential shouldn't specify a default authority or tenant to VisualStudioCodeCredential + with patch(credential) as mock_credential: + DefaultAzureCredential() + mock_credential.assert_called_once_with() - with patch(DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential") as mock_credential: - DefaultAzureCredential(visual_studio_code_tenant_id=expected_args["tenant_id"]) - mock_credential.assert_called_once_with(**expected_args) + tenant = {"tenant_id": "the-tenant"} + + with patch(credential) as mock_credential: + DefaultAzureCredential(visual_studio_code_tenant_id=tenant["tenant_id"]) + mock_credential.assert_called_once_with(**tenant) # tenant id can also be specified in $AZURE_TENANT_ID - with patch.dict(os.environ, {EnvironmentVariables.AZURE_TENANT_ID: expected_args["tenant_id"]}, clear=True): - with patch(DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential") as mock_credential: + with patch.dict(os.environ, {EnvironmentVariables.AZURE_TENANT_ID: tenant["tenant_id"]}, clear=True): + with patch(credential) as mock_credential: DefaultAzureCredential() - mock_credential.assert_called_once_with(**expected_args) + mock_credential.assert_called_once_with(**tenant) # keyword argument should override environment variable - with patch.dict( - os.environ, {EnvironmentVariables.AZURE_TENANT_ID: "not-" + expected_args["tenant_id"]}, clear=True - ): - with patch(DefaultAzureCredential.__module__ + ".VisualStudioCodeCredential") as mock_credential: - DefaultAzureCredential(visual_studio_code_tenant_id=expected_args["tenant_id"]) - mock_credential.assert_called_once_with(**expected_args) + with patch.dict(os.environ, {EnvironmentVariables.AZURE_TENANT_ID: "not-" + tenant["tenant_id"]}, clear=True): + with patch(credential) as mock_credential: + DefaultAzureCredential(visual_studio_code_tenant_id=tenant["tenant_id"]) + mock_credential.assert_called_once_with(**tenant) + + # DefaultAzureCredential should pass the authority kwarg along + authority = {"authority": "the-authority"} + with patch(credential) as mock_credential: + DefaultAzureCredential(**authority) + mock_credential.assert_called_once_with(**authority) + + with patch(credential) as mock_credential: + DefaultAzureCredential(visual_studio_code_tenant_id=tenant["tenant_id"], **authority) + mock_credential.assert_called_once_with(**dict(authority, **tenant)) @pytest.mark.asyncio diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index 094deb76432d..0675d8da2547 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -6,11 +6,10 @@ import time from azure.core.credentials import AccessToken -from azure.identity import CredentialUnavailableError, VisualStudioCodeCredential +from azure.identity import AzureAuthorityHosts, CredentialUnavailableError, VisualStudioCodeCredential from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT -from azure.identity._credentials.vscode import get_credentials import pytest from six.moves.urllib_parse import urlparse @@ -21,24 +20,65 @@ except ImportError: # python < 3.3 import mock +GET_REFRESH_TOKEN = VisualStudioCodeCredential.__module__ + ".get_refresh_token" +GET_USER_SETTINGS = VisualStudioCodeCredential.__module__ + ".get_user_settings" + + +def get_credential(user_settings=None, **kwargs): + # defaulting to empty user settings ensures tests work when real user settings are available + with mock.patch(GET_USER_SETTINGS, lambda: user_settings or {}): + return VisualStudioCodeCredential(**kwargs) + + +def test_tenant_id(): + def get_transport(expected_tenant): + return validating_transport( + requests=[ + Request(base_url="https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, expected_tenant)) + ], + responses=[mock_response(json_payload=build_aad_response(access_token="**"))], + ) + + # credential should default to "organizations" tenant + transport = get_transport("organizations") + credential = get_credential(transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + credential.get_token("scope") + assert transport.send.call_count == 1 + + # ... unless VS Code has a tenant configured + user_settings = {"azure.tenant": "vs-code-setting"} + transport = get_transport(user_settings["azure.tenant"]) + credential = get_credential(user_settings, transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + credential.get_token("scope") + assert transport.send.call_count == 1 + + # ... and a tenant specified by the application prevails over VS Code configuration + transport = get_transport("from-application") + credential = get_credential(user_settings, tenant_id="from-application", transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + credential.get_token("scope") + assert transport.send.call_count == 1 + def test_tenant_id_validation(): """The credential should raise ValueError when given an invalid tenant_id""" valid_ids = {"c878a2ab-8ef4-413b-83a0-199afb84d7fb", "contoso.onmicrosoft.com", "organizations", "common"} for tenant in valid_ids: - VisualStudioCodeCredential(tenant_id=tenant) + get_credential(tenant_id=tenant) invalid_ids = {"my tenant", "my_tenant", "/", "\\", '"my-tenant"', "'my-tenant'"} for tenant in invalid_ids: with pytest.raises(ValueError): - VisualStudioCodeCredential(tenant_id=tenant) + get_credential(tenant_id=tenant) def test_no_scopes(): """The credential should raise ValueError when get_token is called with no scopes""" - credential = VisualStudioCodeCredential() + credential = get_credential() with pytest.raises(ValueError): credential.get_token() @@ -49,10 +89,11 @@ def test_policies_configurable(): def send(*_, **__): return mock_response(json_payload=build_aad_response(access_token="**")) - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value="VALUE"): - credential = VisualStudioCodeCredential(policies=[policy], transport=mock.Mock(send=send)) + credential = get_credential(policies=[policy], transport=mock.Mock(send=send)) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): credential.get_token("scope") - assert policy.on_request.called + + assert policy.on_request.called def test_user_agent(): @@ -60,9 +101,8 @@ def test_user_agent(): requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], ) - - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value="VALUE"): - credential = VisualStudioCodeCredential(transport=transport) + credential = get_credential(transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): credential.get_token("scope") @@ -84,28 +124,24 @@ def mock_send(request, **kwargs): assert request.body["refresh_token"] == expected_refresh_token return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) - credential = VisualStudioCodeCredential( - tenant_id=tenant_id, transport=mock.Mock(send=mock_send), authority=authority - ) - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=expected_refresh_token): + credential = get_credential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send), authority=authority) + with mock.patch(GET_REFRESH_TOKEN, return_value=expected_refresh_token): token = credential.get_token("scope") assert token.token == access_token # authority can be configured via environment variable - with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): - credential = VisualStudioCodeCredential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send)) - with mock.patch( - VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=expected_refresh_token - ): + with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}): + credential = get_credential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send)) + with mock.patch(GET_REFRESH_TOKEN, return_value=expected_refresh_token): credential.get_token("scope") assert token.token == access_token def test_credential_unavailable_error(): - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=None): - credential = VisualStudioCodeCredential() + credential = get_credential() + with mock.patch(GET_REFRESH_TOKEN, return_value=None): with pytest.raises(CredentialUnavailableError): - token = credential.get_token("scope") + credential.get_token("scope") def test_redeem_token(): @@ -116,8 +152,8 @@ def test_redeem_token(): mock_client.obtain_token_by_refresh_token = mock.Mock(return_value=expected_token) mock_client.get_cached_access_token = mock.Mock(return_value=None) - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=expected_value): - credential = VisualStudioCodeCredential(_client=mock_client) + with mock.patch(GET_REFRESH_TOKEN, return_value=expected_value): + credential = get_credential(_client=mock_client) token = credential.get_token("scope") assert token is expected_token mock_client.obtain_token_by_refresh_token.assert_called_with(("scope",), expected_value) @@ -132,8 +168,8 @@ def test_cache_refresh_token(): mock_client.get_cached_access_token = mock.Mock(return_value=None) mock_get_credentials = mock.Mock(return_value="VALUE") - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", mock_get_credentials): - credential = VisualStudioCodeCredential(_client=mock_client) + with mock.patch(GET_REFRESH_TOKEN, mock_get_credentials): + credential = get_credential(_client=mock_client) token = credential.get_token("scope") assert token is expected_token assert mock_get_credentials.call_count == 1 @@ -147,12 +183,12 @@ def test_no_obtain_token_if_cached(): mock_client = mock.Mock( obtain_token_by_refresh_token=mock.Mock(return_value=expected_token), - get_cached_access_token=mock.Mock(return_value=expected_token) + get_cached_access_token=mock.Mock(return_value=expected_token), ) - credential = VisualStudioCodeCredential(_client=mock_client) + credential = get_credential(_client=mock_client) with mock.patch( - VisualStudioCodeCredential.__module__ + ".get_credentials", + GET_REFRESH_TOKEN, mock.Mock(side_effect=Exception("credential should not acquire a new token")), ): token = credential.get_token("scope") @@ -162,33 +198,70 @@ def test_no_obtain_token_if_cached(): assert token.expires_on == expected_token.expires_on -@pytest.mark.skipif(not sys.platform.startswith("linux"), reason="This test only runs on Linux") -def test_segfault(): - from azure.identity._internal.linux_vscode_adapter import _get_refresh_token +def test_native_adapter(): + """Exercise the native adapter for the current OS""" - _get_refresh_token("test", "test") + if sys.platform.startswith("darwin"): + from azure.identity._internal.macos_vscode_adapter import get_refresh_token + elif sys.platform.startswith("linux"): + from azure.identity._internal.linux_vscode_adapter import get_refresh_token + elif sys.platform.startswith("win"): + from azure.identity._internal.win_vscode_adapter import get_refresh_token + else: + pytest.skip('unsupported platform "{}"'.format(sys.platform)) - -@pytest.mark.skipif(not sys.platform.startswith("darwin"), reason="This test only runs on MacOS") -def test_mac_keychain_valid_value(): - with mock.patch("msal_extensions.osx.Keychain.get_generic_password", return_value="VALUE"): - assert get_credentials() == "VALUE" - - -@pytest.mark.skipif(not sys.platform.startswith("darwin"), reason="This test only runs on MacOS") -def test_mac_keychain_error(): - from msal_extensions.osx import Keychain, KeychainError - - with mock.patch.object(Keychain, "get_generic_password", side_effect=KeychainError(-1)): - credential = VisualStudioCodeCredential() - with pytest.raises(CredentialUnavailableError): - token = credential.get_token("scope") + # the return value (None in CI, possibly something else on a dev machine) is irrelevant + # because the goal is simply to expose a native interop problem like a segfault + get_refresh_token("AzureCloud") def test_adfs(): """The credential should raise CredentialUnavailableError when configured for ADFS""" - credential = VisualStudioCodeCredential(tenant_id="adfs") + credential = get_credential(tenant_id="adfs") with pytest.raises(CredentialUnavailableError) as ex: credential.get_token("scope") assert "adfs" in ex.value.message.lower() + + +@pytest.mark.parametrize( + "cloud,authority", + ( + ("AzureCloud", AzureAuthorityHosts.AZURE_PUBLIC_CLOUD), + ("AzureChinaCloud", AzureAuthorityHosts.AZURE_CHINA), + ("AzureGermanCloud", AzureAuthorityHosts.AZURE_GERMANY), + ("AzureUSGovernment", AzureAuthorityHosts.AZURE_GOVERNMENT), + ), +) +def test_reads_cloud_settings(cloud, authority): + """the credential should read authority and tenant from VS Code settings when an application doesn't specify them""" + + expected_tenant = "tenant-id" + user_settings = {"azure.cloud": cloud, "azure.tenant": expected_tenant} + + transport = validating_transport( + requests=[Request(base_url="https://{}/{}".format(authority, expected_tenant))], + responses=[mock_response(json_payload=build_aad_response(access_token="**"))], + ) + + credential = get_credential(user_settings, transport=transport) + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + credential.get_token("scope") + + assert transport.send.call_count == 1 + + +def test_no_user_settings(): + """the credential should default to Public Cloud and "organizations" tenant when it can't read VS Code settings""" + + transport = validating_transport( + requests=[Request(base_url="https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, "organizations"))], + responses=[mock_response(json_payload=build_aad_response(access_token="**"))], + ) + + credential = get_credential(transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + credential.get_token("scope") + + assert transport.send.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index 898b14bd57d7..6afeb58f655a 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -7,7 +7,7 @@ from urllib.parse import urlparse from azure.core.credentials import AccessToken -from azure.identity import CredentialUnavailableError +from azure.identity import AzureAuthorityHosts, CredentialUnavailableError from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import VisualStudioCodeCredential @@ -16,6 +16,46 @@ from helpers import build_aad_response, mock_response, Request from helpers_async import async_validating_transport, wrap_in_future +from test_vscode_credential import GET_REFRESH_TOKEN, GET_USER_SETTINGS + + +def get_credential(user_settings=None, **kwargs): + # defaulting to empty user settings ensures tests work when real user settings are available + with mock.patch(GET_USER_SETTINGS, lambda: user_settings or {}): + return VisualStudioCodeCredential(**kwargs) + + +@pytest.mark.asyncio +async def test_tenant_id(): + def get_transport(expected_tenant): + return async_validating_transport( + requests=[ + Request(base_url="https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, expected_tenant)) + ], + responses=[mock_response(json_payload=build_aad_response(access_token="**"))], + ) + + # credential should default to "organizations" tenant + transport = get_transport("organizations") + credential = get_credential(transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + await credential.get_token("scope") + assert transport.send.call_count == 1 + + # ... unless VS Code has a tenant configured + user_settings = {"azure.tenant": "vs-code-setting"} + transport = get_transport(user_settings["azure.tenant"]) + credential = get_credential(user_settings, transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + await credential.get_token("scope") + assert transport.send.call_count == 1 + + # ... and a tenant specified by the application prevails over VS Code configuration + transport = get_transport("from-application") + credential = get_credential(user_settings, tenant_id="from-application", transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + await credential.get_token("scope") + assert transport.send.call_count == 1 def test_tenant_id_validation(): @@ -23,19 +63,19 @@ def test_tenant_id_validation(): valid_ids = {"c878a2ab-8ef4-413b-83a0-199afb84d7fb", "contoso.onmicrosoft.com", "organizations", "common"} for tenant in valid_ids: - VisualStudioCodeCredential(tenant_id=tenant) + get_credential(tenant_id=tenant) invalid_ids = {"my tenant", "my_tenant", "/", "\\", '"my-tenant"', "'my-tenant'"} for tenant in invalid_ids: with pytest.raises(ValueError): - VisualStudioCodeCredential(tenant_id=tenant) + get_credential(tenant_id=tenant) @pytest.mark.asyncio async def test_no_scopes(): """The credential should raise ValueError when get_token is called with no scopes""" - credential = VisualStudioCodeCredential() + credential = get_credential() with pytest.raises(ValueError): await credential.get_token() @@ -47,10 +87,11 @@ async def test_policies_configurable(): async def send(*_, **__): return mock_response(json_payload=build_aad_response(access_token="**")) - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value="VALUE"): - credential = VisualStudioCodeCredential(policies=[policy], transport=mock.Mock(send=send)) + credential = get_credential(policies=[policy], transport=mock.Mock(send=send)) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): await credential.get_token("scope") - assert policy.on_request.called + + assert policy.on_request.called @pytest.mark.asyncio @@ -59,9 +100,8 @@ async def test_user_agent(): requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], ) - - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value="VALUE"): - credential = VisualStudioCodeCredential(transport=transport) + credential = get_credential(transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): await credential.get_token("scope") @@ -84,18 +124,18 @@ async def mock_send(request, **kwargs): assert request.body["refresh_token"] == expected_refresh_token return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) - credential = VisualStudioCodeCredential( + credential = get_credential( tenant_id=tenant_id, transport=mock.Mock(send=mock_send), authority=authority ) - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=expected_refresh_token): + with mock.patch(GET_REFRESH_TOKEN, return_value=expected_refresh_token): token = await credential.get_token("scope") assert token.token == access_token # authority can be configured via environment variable with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): - credential = VisualStudioCodeCredential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send)) + credential = get_credential(tenant_id=tenant_id, transport=mock.Mock(send=mock_send)) with mock.patch( - VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=expected_refresh_token + GET_REFRESH_TOKEN, return_value=expected_refresh_token ): await credential.get_token("scope") assert token.token == access_token @@ -103,10 +143,10 @@ async def mock_send(request, **kwargs): @pytest.mark.asyncio async def test_credential_unavailable_error(): - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=None): - credential = VisualStudioCodeCredential() + credential = get_credential() + with mock.patch(GET_REFRESH_TOKEN, return_value=None): with pytest.raises(CredentialUnavailableError): - token = await credential.get_token("scope") + await credential.get_token("scope") @pytest.mark.asyncio @@ -119,8 +159,8 @@ async def test_redeem_token(): mock_client.obtain_token_by_refresh_token = wrap_in_future(token_by_refresh_token) mock_client.get_cached_access_token = mock.Mock(return_value=None) - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", return_value=expected_value): - credential = VisualStudioCodeCredential(_client=mock_client) + with mock.patch(GET_REFRESH_TOKEN, return_value=expected_value): + credential = get_credential(_client=mock_client) token = await credential.get_token("scope") assert token is expected_token token_by_refresh_token.assert_called_with(("scope",), expected_value) @@ -136,11 +176,11 @@ async def test_cache_refresh_token(): mock_client.get_cached_access_token = mock.Mock(return_value=None) mock_get_credentials = mock.Mock(return_value="VALUE") - with mock.patch(VisualStudioCodeCredential.__module__ + ".get_credentials", mock_get_credentials): - credential = VisualStudioCodeCredential(_client=mock_client) - token = await credential.get_token("scope") + credential = get_credential(_client=mock_client) + with mock.patch(GET_REFRESH_TOKEN, mock_get_credentials): + await credential.get_token("scope") assert mock_get_credentials.call_count == 1 - token = await credential.get_token("scope") + await credential.get_token("scope") assert mock_get_credentials.call_count == 1 @@ -154,9 +194,9 @@ async def test_no_obtain_token_if_cached(): obtain_token_by_refresh_token=wrap_in_future(token_by_refresh_token) ) - credential = VisualStudioCodeCredential(_client=mock_client) + credential = get_credential(_client=mock_client) with mock.patch( - VisualStudioCodeCredential.__module__ + ".get_credentials", + GET_REFRESH_TOKEN, mock.Mock(side_effect=Exception("credential should not acquire a new token")), ): token = await credential.get_token("scope") @@ -170,7 +210,51 @@ async def test_no_obtain_token_if_cached(): async def test_adfs(): """The credential should raise CredentialUnavailableError when configured for ADFS""" - credential = VisualStudioCodeCredential(tenant_id="adfs") + credential = get_credential(tenant_id="adfs") with pytest.raises(CredentialUnavailableError) as ex: await credential.get_token("scope") assert "adfs" in ex.value.message.lower() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "cloud,authority", + ( + ("AzureCloud", AzureAuthorityHosts.AZURE_PUBLIC_CLOUD), + ("AzureChinaCloud", AzureAuthorityHosts.AZURE_CHINA), + ("AzureGermanCloud", AzureAuthorityHosts.AZURE_GERMANY), + ("AzureUSGovernment", AzureAuthorityHosts.AZURE_GOVERNMENT), + ), +) +async def test_reads_cloud_settings(cloud, authority): + """the credential should read authority and tenant from VS Code settings when an application doesn't specify them""" + + expected_tenant = "tenant-id" + user_settings = {"azure.cloud": cloud, "azure.tenant": expected_tenant} + + transport = async_validating_transport( + requests=[Request(base_url="https://{}/{}".format(authority, expected_tenant))], + responses=[mock_response(json_payload=build_aad_response(access_token="**"))], + ) + + credential = get_credential(user_settings, transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + await credential.get_token("scope") + + assert transport.send.call_count == 1 + + +@pytest.mark.asyncio +async def test_no_user_settings(): + """the credential should default to Public Cloud and "organizations" tenant when it can't read VS Code settings""" + + transport = async_validating_transport( + requests=[Request(base_url="https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, "organizations"))], + responses=[mock_response(json_payload=build_aad_response(access_token="**"))], + ) + + credential = get_credential(transport=transport) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + await credential.get_token("scope") + + assert transport.send.call_count == 1