diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index 2b165d0a0a52..e29a48854e68 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -22,7 +22,7 @@ UserAgentPolicy, ) from azure.core.pipeline.transport import RequestsTransport, HttpRequest -from ._constants import AZURE_CLI_CLIENT_ID +from ._constants import AZURE_CLI_CLIENT_ID, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from ._internal import get_default_authority, normalize_authority from ._internal.user_agent import USER_AGENT @@ -65,17 +65,32 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl authority = normalize_authority(authority) if authority else get_default_authority() self._auth_url = "/".join((authority, tenant.strip("/"), "oauth2/v2.0/token")) self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache + self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY + self._token_refresh_offset = DEFAULT_REFRESH_OFFSET + self._last_refresh_time = 0 @property def auth_url(self): return self._auth_url + def should_refresh(self, token): + # type: (AccessToken) -> bool + """ check if the token needs refresh or not + """ + expires_on = int(token.expires_on) + now = int(time.time()) + if expires_on - now > self._token_refresh_offset: + return False + if now - self._last_refresh_time < self._token_refresh_retry_delay: + return False + return True + def get_cached_token(self, scopes): # type: (Iterable[str]) -> Optional[AccessToken] tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) for token in tokens: expires_on = int(token["expires_on"]) - if expires_on - 300 > int(time.time()): + if expires_on > int(time.time()): return AccessToken(token["secret"], expires_on) return None @@ -217,6 +232,7 @@ def request_token( # type: (...) -> AccessToken request = self._prepare_request(method, headers=headers, form_data=form_data, params=params) request_time = int(time.time()) + self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time response = self._pipeline.run(request, stream=False, **kwargs) token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index a47ebdeb9920..4d217d7dc716 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -7,6 +7,8 @@ AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" AZURE_VSCODE_CLIENT_ID = "aebc6443-996d-45c2-90f0-388ff96faa56" VSCODE_CREDENTIALS_SECTION = "VS Code Azure" +DEFAULT_REFRESH_OFFSET = 300 +DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30 class KnownAuthorities: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index 3568f8c921ce..b02e64baf684 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -64,7 +64,15 @@ def get_token(self, *scopes, **kwargs): self._authorization_code = None # auth codes are single-use return token - token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs) + token = self._client.get_cached_access_token(scopes) + if not token: + token = self._redeem_refresh_token(scopes, **kwargs) + elif self._client.should_refresh(token): + try: + self._redeem_refresh_token(scopes, **kwargs) + except Exception: # pylint: disable=broad-except + pass + if not token: raise ClientAuthenticationError( message="No authorization code, cached access token, or refresh token available." diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index 81adb2621a96..d88972c75265 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -48,6 +48,11 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) + elif self._client.should_refresh(token): + try: + self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) + except Exception: # pylint: disable=broad-except + pass return token def _get_auth_client(self, tenant_id, client_id, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py index 4e20c2bd900b..9e5e504ed785 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py @@ -49,6 +49,11 @@ def get_token(self, *scopes, **kwargs): token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) + elif self._client.should_refresh(token): + try: + self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) + except Exception: # pylint: disable=broad-except + pass return token def _get_auth_client(self, tenant_id, client_id, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 8d29f0aae70c..3cf055f816c1 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -170,28 +170,37 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument token = self._client.get_cached_token(scopes) if not token: - resource = scopes[0] - if resource.endswith("/.default"): - resource = resource[: -len("/.default")] - params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config) - + token = self._refresh_token(*scopes) + elif self._client.should_refresh(token): try: - token = self._client.request_token(scopes, method="GET", params=params) - except HttpResponseError as ex: - # 400 in response to a token request indicates managed identity is disabled, - # or the identity with the specified client_id is not available - if ex.status_code == 400: - self._endpoint_available = False - message = "ManagedIdentityCredential authentication unavailable. " - if self._identity_config: - message += "The requested identity has not been assigned to this resource." - else: - message += "No identity has been assigned to this resource." - six.raise_from(CredentialUnavailableError(message=message), ex) - - # any other error is unexpected - six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None) + token = self._refresh_token(*scopes) + except Exception: # pylint: disable=broad-except + pass + + return token + def _refresh_token(self, *scopes): + resource = scopes[0] + if resource.endswith("/.default"): + resource = resource[: -len("/.default")] + params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config) + + try: + token = self._client.request_token(scopes, method="GET", params=params) + except HttpResponseError as ex: + # 400 in response to a token request indicates managed identity is disabled, + # or the identity with the specified client_id is not available + if ex.status_code == 400: + self._endpoint_available = False + message = "ManagedIdentityCredential authentication unavailable. " + if self._identity_config: + message += "The requested identity has not been assigned to this resource." + else: + message += "No identity has been assigned to this resource." + six.raise_from(CredentialUnavailableError(message=message), ex) + + # any other error is unexpected + six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None) return token @@ -227,16 +236,25 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument token = self._client.get_cached_token(scopes) if not token: - resource = scopes[0] - if resource.endswith("/.default"): - resource = resource[: -len("/.default")] - secret = os.environ.get(EnvironmentVariables.MSI_SECRET) - if secret: - # MSI_ENDPOINT and MSI_SECRET set -> App Service - token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret) - else: - # only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell) - token = self._request_legacy_token(scopes=scopes, resource=resource) + token = self._refresh_token(*scopes) + elif self._client.should_refresh(token): + try: + token = self._refresh_token(*scopes) + except Exception: # pylint: disable=broad-except + pass + return token + + def _refresh_token(self, *scopes): + resource = scopes[0] + if resource.endswith("/.default"): + resource = resource[: -len("/.default")] + secret = os.environ.get(EnvironmentVariables.MSI_SECRET) + if secret: + # MSI_ENDPOINT and MSI_SECRET set -> App Service + token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret) + else: + # only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell) + token = self._request_legacy_token(scopes=scopes, resource=resource) return token def _request_app_service_token(self, scopes, resource, secret): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py index d2c5e4ee6f6e..a5bcf3fc66c8 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py @@ -9,11 +9,11 @@ from .._internal.aad_client import AadClient if sys.platform.startswith("win"): - from .win_vscode_adapter import get_credentials + from .._internal.win_vscode_adapter import get_credentials elif sys.platform.startswith("darwin"): - from .macos_vscode_adapter import get_credentials + from .._internal.macos_vscode_adapter import get_credentials else: - from .linux_vscode_adapter import get_credentials + from .._internal.linux_vscode_adapter import get_credentials if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports @@ -47,9 +47,17 @@ def get_token(self, *scopes, **kwargs): token = self._client.get_cached_access_token(scopes) - if token: - return token + if not token: + token = self._redeem_refresh_token(scopes, **kwargs) + elif self._client.should_refresh(token): + try: + self._redeem_refresh_token(scopes, **kwargs) + except Exception: # pylint: disable=broad-except + pass + return token + def _redeem_refresh_token(self, scopes, **kwargs): + # type: (Sequence[str], **Any) -> Optional[AccessToken] if not self._refresh_token: self._refresh_token = get_credentials() if not self._refresh_token: diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index e1af4f949626..819139f97ff7 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -16,6 +16,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from . import get_default_authority, normalize_authority +from .._constants import DEFAULT_TOKEN_REFRESH_RETRY_DELAY, DEFAULT_REFRESH_OFFSET try: from typing import TYPE_CHECKING @@ -48,13 +49,16 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs): self._cache = cache or TokenCache() self._client_id = client_id self._pipeline = self._build_pipeline(**kwargs) + self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY + self._token_refresh_offset = DEFAULT_REFRESH_OFFSET + self._last_refresh_time = 0 def get_cached_access_token(self, scopes, query=None): # type: (Sequence[str], Optional[dict]) -> Optional[AccessToken] tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query) for token in tokens: expires_on = int(token["expires_on"]) - if expires_on - 300 > int(time.time()): + if expires_on > int(time.time()): return AccessToken(token["secret"], expires_on) return None @@ -63,6 +67,19 @@ def get_cached_refresh_tokens(self, scopes): """Assumes all cached refresh tokens belong to the same user""" return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes)) + def should_refresh(self, token): + # type: (AccessToken) -> bool + """ check if the token needs refresh or not + """ + expires_on = int(token.expires_on) + now = int(time.time()) + if expires_on - now > self._token_refresh_offset: + return False + if now - self._last_refresh_time < self._token_refresh_retry_delay: + return False + return True + + @abc.abstractmethod def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): pass @@ -85,6 +102,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs): def _process_response(self, response, request_time): # type: (PipelineResponse, int) -> AccessToken + self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/linux_vscode_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/linux_vscode_adapter.py similarity index 100% rename from sdk/identity/azure-identity/azure/identity/_credentials/linux_vscode_adapter.py rename to sdk/identity/azure-identity/azure/identity/_internal/linux_vscode_adapter.py diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/macos_vscode_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/macos_vscode_adapter.py similarity index 100% rename from sdk/identity/azure-identity/azure/identity/_credentials/macos_vscode_adapter.py rename to sdk/identity/azure-identity/azure/identity/_internal/macos_vscode_adapter.py diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/win_vscode_adapter.py b/sdk/identity/azure-identity/azure/identity/_internal/win_vscode_adapter.py similarity index 100% rename from sdk/identity/azure-identity/azure/identity/_credentials/win_vscode_adapter.py rename to sdk/identity/azure-identity/azure/identity/_internal/win_vscode_adapter.py diff --git a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py index 9cfe13bd9498..8f612e717ca8 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py @@ -75,6 +75,7 @@ async def request_token( ) -> AccessToken: request = self._prepare_request(method, headers=headers, form_data=form_data, params=params) request_time = int(time.time()) + self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time response = await self._pipeline.run(request, stream=False, **kwargs) token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index 0b5fbb53dc33..90edc002b243 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -80,7 +80,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": token = self._client.get_cached_access_token(scopes) if not token: token = await self._redeem_refresh_token(scopes, **kwargs) - + elif self._client.should_refresh(token): + try: + await self._redeem_refresh_token(scopes, **kwargs) + except Exception: # pylint: disable=broad-except + pass if not token: raise ClientAuthenticationError( message="No authorization code, cached access token, or refresh token available." diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index 1b044a24c0e1..ade6bb8e7d8c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -54,6 +54,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) + elif self._client.should_refresh(token): + try: + await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) + except Exception: # pylint: disable=broad-except + pass return token def _get_auth_client(self, tenant_id, client_id, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index 87b5472760e6..767a80b3cf84 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -55,6 +55,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) + elif self._client.should_refresh(token): + try: + await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) + except Exception: # pylint: disable=broad-except + pass return token def _get_auth_client(self, tenant_id, client_id, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 5e5bf172f43e..6b17a55ada91 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -130,30 +130,39 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> AccessToken: # pyli token = self._client.get_cached_token(scopes) if not token: - resource = scopes[0] - if resource.endswith("/.default"): - resource = resource[: -len("/.default")] - params = {"api-version": "2018-02-01", "resource": resource, **self._identity_config} - + token = await self._refresh_token(*scopes) + elif self._client.should_refresh(token): try: - token = await self._client.request_token(scopes, method="GET", params=params) - except HttpResponseError as ex: - # 400 in response to a token request indicates managed identity is disabled, - # or the identity with the specified client_id is not available - if ex.status_code == 400: - self._endpoint_available = False - message = "ManagedIdentityCredential authentication unavailable. " - if self._identity_config: - message += "The requested identity has not been assigned to this resource." - else: - message += "No identity has been assigned to this resource." - raise CredentialUnavailableError(message=message) from ex - - # any other error is unexpected - raise ClientAuthenticationError(message=ex.message, response=ex.response) from None + token = await self._refresh_token(*scopes) + except Exception: # pylint: disable=broad-except + pass return token + async def _refresh_token(self, *scopes): + resource = scopes[0] + if resource.endswith("/.default"): + resource = resource[: -len("/.default")] + params = {"api-version": "2018-02-01", "resource": resource, **self._identity_config} + + try: + token = await self._client.request_token(scopes, method="GET", params=params) + except HttpResponseError as ex: + # 400 in response to a token request indicates managed identity is disabled, + # or the identity with the specified client_id is not available + if ex.status_code == 400: + self._endpoint_available = False + message = "ManagedIdentityCredential authentication unavailable. " + if self._identity_config: + message += "The requested identity has not been assigned to this resource." + else: + message += "No identity has been assigned to this resource." + raise CredentialUnavailableError(message=message) from ex + + # any other error is unexpected + raise ClientAuthenticationError(message=ex.message, response=ex.response) from None + return token + class MsiCredential(_AsyncManagedIdentityBase): """Authenticates via the MSI endpoint in an App Service or Cloud Shell environment. @@ -184,17 +193,26 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> AccessToken: # pyli token = self._client.get_cached_token(scopes) if not token: - resource = scopes[0] - if resource.endswith("/.default"): - resource = resource[: -len("/.default")] - - secret = os.environ.get(EnvironmentVariables.MSI_SECRET) - if secret: - # MSI_ENDPOINT and MSI_SECRET set -> App Service - token = await self._request_app_service_token(scopes=scopes, resource=resource, secret=secret) - else: - # only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell) - token = await self._request_legacy_token(scopes=scopes, resource=resource) + token = await self._refresh_token(*scopes) + elif self._client.should_refresh(token): + try: + token = await self._refresh_token(*scopes) + except Exception: # pylint: disable=broad-except + pass + return token + + async def _refresh_token(self, *scopes): + resource = scopes[0] + if resource.endswith("/.default"): + resource = resource[: -len("/.default")] + + secret = os.environ.get(EnvironmentVariables.MSI_SECRET) + if secret: + # MSI_ENDPOINT and MSI_SECRET set -> App Service + token = await self._request_app_service_token(scopes=scopes, resource=resource, secret=secret) + else: + # only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell) + token = await self._request_legacy_token(scopes=scopes, resource=resource) return token async def _request_app_service_token(self, scopes, resource, secret): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py index 798aefe8c5c2..6b1da3a6f8ae 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode_credential.py @@ -50,9 +50,16 @@ async def get_token(self, *scopes, **kwargs): raise ValueError("'get_token' requires at least one scope") token = self._client.get_cached_access_token(scopes) - if token: - return token + if not token: + token = await self._redeem_refresh_token(scopes, **kwargs) + elif self._client.should_refresh(token): + try: + await self._redeem_refresh_token(scopes, **kwargs) + except Exception: # pylint: disable=broad-except + pass + return token + async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]": if not self._refresh_token: self._refresh_token = get_credentials() if not self._refresh_token: diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index daa40c3d4659..b878d4bbf5e8 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -3,10 +3,11 @@ # Licensed under the MIT License. # ------------------------------------ import functools - +import time from azure.core.exceptions import ClientAuthenticationError -from azure.identity._constants import EnvironmentVariables +from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from azure.identity._internal.aad_client import AadClient +from azure.core.credentials import AccessToken import pytest from msal import TokenCache from six.moves.urllib_parse import urlparse @@ -201,3 +202,24 @@ def send(request, **_): assert transport.send.call_count == 1 assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1 assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 0 + + +def test_should_refresh(): + client = AadClient("test", "test") + now = int(time.time()) + + # do not need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + 1) + should_refresh = client.should_refresh(token) + assert not should_refresh + + # need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + should_refresh = client.should_refresh(token) + assert should_refresh + + # not exceed cool down time, do not refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY + 1 + should_refresh = client.should_refresh(token) + assert not should_refresh diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index e43f70bd4369..ab9cd8208809 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -5,10 +5,11 @@ import functools from unittest.mock import Mock, patch from urllib.parse import urlparse - +import time from azure.core.exceptions import ClientAuthenticationError -from azure.identity._constants import EnvironmentVariables +from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from azure.identity.aio._internal.aad_client import AadClient +from azure.core.credentials import AccessToken from msal import TokenCache import pytest @@ -208,3 +209,24 @@ async def send(request, **_): assert transport.send.call_count == 1 assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1 assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 0 + + +async def test_should_refresh(): + client = AadClient("test", "test") + now = int(time.time()) + + # do not need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + 1) + should_refresh = client.should_refresh(token) + assert not should_refresh + + # need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + should_refresh = client.should_refresh(token) + assert should_refresh + + # not exceed cool down time, do not refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY + 1 + should_refresh = client.should_refresh(token) + assert not should_refresh diff --git a/sdk/identity/azure-identity/tests/test_authn_client.py b/sdk/identity/azure-identity/tests/test_authn_client.py index 6732d43cd4dc..c5dbbe41394a 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client.py +++ b/sdk/identity/azure-identity/tests/test_authn_client.py @@ -14,7 +14,7 @@ from azure.core.credentials import AccessToken from azure.identity._authn_client import AuthnClient -from azure.identity._constants import EnvironmentVariables +from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY import pytest from six.moves.urllib_parse import urlparse from helpers import mock_response @@ -233,3 +233,24 @@ def mock_send(request, **kwargs): client.request_token(("scope",)) request = client.get_refresh_token_grant_request({"secret": "***"}, "scope") validate_url(request.url) + + +def test_should_refresh(): + client = AuthnClient(endpoint="http://foo") + now = int(time.time()) + + # do not need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + 1) + should_refresh = client.should_refresh(token) + assert not should_refresh + + # need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + should_refresh = client.should_refresh(token) + assert should_refresh + + # not exceed cool down time, do not refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY + 1 + should_refresh = client.should_refresh(token) + assert not should_refresh diff --git a/sdk/identity/azure-identity/tests/test_authn_client_async.py b/sdk/identity/azure-identity/tests/test_authn_client_async.py index ab94c2c236c4..d80367d9b583 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client_async.py +++ b/sdk/identity/azure-identity/tests/test_authn_client_async.py @@ -3,11 +3,13 @@ # Licensed under the MIT License. # ------------------------------------ import asyncio +import time from unittest.mock import Mock, patch from urllib.parse import urlparse import pytest -from azure.identity._constants import EnvironmentVariables +from azure.core.credentials import AccessToken +from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from azure.identity.aio._authn_client import AsyncAuthnClient from helpers import mock_response @@ -35,3 +37,24 @@ def mock_send(request, **kwargs): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): client = AsyncAuthnClient(tenant=tenant_id, transport=Mock(send=wrap_in_future(mock_send))) await client.request_token(("scope",)) + + +async def test_should_refresh(): + client = AsyncAuthnClient(endpoint="http://foo") + now = int(time.time()) + + # do not need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET + 1) + should_refresh = client.should_refresh(token) + assert not should_refresh + + # need refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + should_refresh = client.should_refresh(token) + assert should_refresh + + # not exceed cool down time, do not refresh + token = AccessToken("token", now + DEFAULT_REFRESH_OFFSET - 1) + client._last_refresh_time = now - DEFAULT_TOKEN_REFRESH_RETRY_DELAY + 1 + should_refresh = client.should_refresh(token) + assert not should_refresh diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index a0b320410bc8..6f22ad7ad8d6 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -94,7 +94,7 @@ def test_cache_refresh_token(): def test_no_obtain_token_if_cached(): expected_token = AccessToken("token", 42) - mock_client = mock.Mock(spec=object) + mock_client = mock.Mock(should_refresh=lambda _: False) mock_client.obtain_token_by_refresh_token = mock.Mock(return_value=expected_token) mock_client.get_cached_access_token = mock.Mock(return_value="VALUE") @@ -106,7 +106,7 @@ def test_no_obtain_token_if_cached(): @pytest.mark.skipif(not sys.platform.startswith("linux"), reason="This test only runs on Linux") def test_segfault(): - from azure.identity._credentials.linux_vscode_adapter import _get_refresh_token + from azure.identity._internal.linux_vscode_adapter import _get_refresh_token _get_refresh_token("test", "test") diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index 1f26651d45d8..5207c73641fe 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -97,7 +97,7 @@ async def test_cache_refresh_token(): async def test_no_obtain_token_if_cached(): expected_token = AccessToken("token", 42) - mock_client = mock.Mock(spec=object) + mock_client = mock.Mock(should_refresh=lambda _: False) token_by_refresh_token = mock.Mock(return_value=expected_token) mock_client.obtain_token_by_refresh_token = wrap_in_future(token_by_refresh_token) mock_client.get_cached_access_token = mock.Mock(return_value="VALUE")