Skip to content

Commit

Permalink
token refresh offset (#12136)
Browse files Browse the repository at this point in the history
* token refresh offset
  • Loading branch information
xiangyan99 authored Jul 17, 2020
1 parent 8a41f87 commit 117a6f5
Show file tree
Hide file tree
Showing 23 changed files with 290 additions and 82 deletions.
20 changes: 18 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
UserAgentPolicy,
)
from azure.core.pipeline.transport import RequestsTransport, HttpRequest
from ._constants import AZURE_CLI_CLIENT_ID
from ._constants import AZURE_CLI_CLIENT_ID, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from ._internal import get_default_authority, normalize_authority
from ._internal.user_agent import USER_AGENT

Expand Down Expand Up @@ -65,17 +65,32 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl
authority = normalize_authority(authority) if authority else get_default_authority()
self._auth_url = "/".join((authority, tenant.strip("/"), "oauth2/v2.0/token"))
self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
self._last_refresh_time = 0

@property
def auth_url(self):
return self._auth_url

def should_refresh(self, token):
# type: (AccessToken) -> bool
""" check if the token needs refresh or not
"""
expires_on = int(token.expires_on)
now = int(time.time())
if expires_on - now > self._token_refresh_offset:
return False
if now - self._last_refresh_time < self._token_refresh_retry_delay:
return False
return True

def get_cached_token(self, scopes):
# type: (Iterable[str]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on - 300 > int(time.time()):
if expires_on > int(time.time()):
return AccessToken(token["secret"], expires_on)
return None

Expand Down Expand Up @@ -217,6 +232,7 @@ def request_token(
# type: (...) -> AccessToken
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
request_time = int(time.time())
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
response = self._pipeline.run(request, stream=False, **kwargs)
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
AZURE_VSCODE_CLIENT_ID = "aebc6443-996d-45c2-90f0-388ff96faa56"
VSCODE_CREDENTIALS_SECTION = "VS Code Azure"
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30


class KnownAuthorities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,15 @@ def get_token(self, *scopes, **kwargs):
self._authorization_code = None # auth codes are single-use
return token

token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs)
token = self._client.get_cached_access_token(scopes)
if not token:
token = self._redeem_refresh_token(scopes, **kwargs)
elif self._client.should_refresh(token):
try:
self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass

if not token:
raise ClientAuthenticationError(
message="No authorization code, cached access token, or refresh token available."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
elif self._client.should_refresh(token):
try:
self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def get_token(self, *scopes, **kwargs):
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
elif self._client.should_refresh(token):
try:
self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,28 +170,37 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

token = self._client.get_cached_token(scopes)
if not token:
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config)

token = self._refresh_token(*scopes)
elif self._client.should_refresh(token):
try:
token = self._client.request_token(scopes, method="GET", params=params)
except HttpResponseError as ex:
# 400 in response to a token request indicates managed identity is disabled,
# or the identity with the specified client_id is not available
if ex.status_code == 400:
self._endpoint_available = False
message = "ManagedIdentityCredential authentication unavailable. "
if self._identity_config:
message += "The requested identity has not been assigned to this resource."
else:
message += "No identity has been assigned to this resource."
six.raise_from(CredentialUnavailableError(message=message), ex)

# any other error is unexpected
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
token = self._refresh_token(*scopes)
except Exception: # pylint: disable=broad-except
pass

return token

def _refresh_token(self, *scopes):
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config)

try:
token = self._client.request_token(scopes, method="GET", params=params)
except HttpResponseError as ex:
# 400 in response to a token request indicates managed identity is disabled,
# or the identity with the specified client_id is not available
if ex.status_code == 400:
self._endpoint_available = False
message = "ManagedIdentityCredential authentication unavailable. "
if self._identity_config:
message += "The requested identity has not been assigned to this resource."
else:
message += "No identity has been assigned to this resource."
six.raise_from(CredentialUnavailableError(message=message), ex)

# any other error is unexpected
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
return token


Expand Down Expand Up @@ -227,16 +236,25 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

token = self._client.get_cached_token(scopes)
if not token:
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
if secret:
# MSI_ENDPOINT and MSI_SECRET set -> App Service
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
else:
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
token = self._request_legacy_token(scopes=scopes, resource=resource)
token = self._refresh_token(*scopes)
elif self._client.should_refresh(token):
try:
token = self._refresh_token(*scopes)
except Exception: # pylint: disable=broad-except
pass
return token

def _refresh_token(self, *scopes):
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
if secret:
# MSI_ENDPOINT and MSI_SECRET set -> App Service
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
else:
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
token = self._request_legacy_token(scopes=scopes, resource=resource)
return token

def _request_app_service_token(self, scopes, resource, secret):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from .._internal.aad_client import AadClient

if sys.platform.startswith("win"):
from .win_vscode_adapter import get_credentials
from .._internal.win_vscode_adapter import get_credentials
elif sys.platform.startswith("darwin"):
from .macos_vscode_adapter import get_credentials
from .._internal.macos_vscode_adapter import get_credentials
else:
from .linux_vscode_adapter import get_credentials
from .._internal.linux_vscode_adapter import get_credentials

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
Expand Down Expand Up @@ -47,9 +47,17 @@ def get_token(self, *scopes, **kwargs):

token = self._client.get_cached_access_token(scopes)

if token:
return token
if not token:
token = self._redeem_refresh_token(scopes, **kwargs)
elif self._client.should_refresh(token):
try:
self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority
from .._constants import DEFAULT_TOKEN_REFRESH_RETRY_DELAY, DEFAULT_REFRESH_OFFSET

try:
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -48,13 +49,16 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
self._cache = cache or TokenCache()
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
self._last_refresh_time = 0

def get_cached_access_token(self, scopes, query=None):
# type: (Iterable[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on - 300 > int(time.time()):
if expires_on > int(time.time()):
return AccessToken(token["secret"], expires_on)
return None

Expand All @@ -63,6 +67,19 @@ def get_cached_refresh_tokens(self, scopes):
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

def should_refresh(self, token):
# type: (AccessToken) -> bool
""" check if the token needs refresh or not
"""
expires_on = int(token.expires_on)
now = int(time.time())
if expires_on - now > self._token_refresh_offset:
return False
if now - self._last_refresh_time < self._token_refresh_retry_delay:
return False
return True


@abc.abstractmethod
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass
Expand All @@ -85,6 +102,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):

def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time

content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def request_token( # pylint:disable=invalid-overridden-method
) -> AccessToken:
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
request_time = int(time.time())
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
response = await self._pipeline.run(request, stream=False, **kwargs)
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes)
if not token:
token = await self._redeem_refresh_token(scopes, **kwargs)

elif self._client.should_refresh(token):
try:
await self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass
if not token:
raise ClientAuthenticationError(
message="No authorization code, cached access token, or refresh token available."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
elif self._client.should_refresh(token):
try:
await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
elif self._client.should_refresh(token):
try:
await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Loading

0 comments on commit 117a6f5

Please sign in to comment.