Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use refresh token setting from credential #12278

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argu
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]
self._load_token_refresh_options()

@staticmethod
def _enforce_https(request):
Expand Down Expand Up @@ -68,8 +69,14 @@ def _update_headers(headers, token):
@property
def _need_new_token(self):
# type: () -> bool
return not self._token or self._token.expires_on - time.time() < 300

return not self._token or self._token.expires_on - time.time() < self._token_refresh_offset

def _load_token_refresh_options(self):
try:
token_refresh_options = self._credential.get_token_refresh_options()
self._token_refresh_offset = int(token_refresh_options.get("token_refresh_offset", 300))
except Exception: # pylint: disable=broad-except
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just catch AttributeError (if get_token_refresh_options is not defined), and let a potential "NotANumber" throw from the int conversion

self._token_refresh_offset = 300

class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
"""Adds a bearer token Authorization header to requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from unittest.mock import Mock

from azure.core.credentials import AccessToken
from azure.core.credentials import AccessToken, AzureKeyCredential
from azure.core.exceptions import AzureError, ServiceRequestError
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, SansIOHTTPPolicy
Expand Down Expand Up @@ -160,3 +160,47 @@ def get_completed_future(result=None):
fut = asyncio.Future()
fut.set_result(result)
return fut


@pytest.mark.asyncio
async def test_bearer_policy_refresh_option():
def get_token_refresh_options():
return {"token_refresh_offset": 100}
credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 100


@pytest.mark.asyncio
async def test_bearer_policy_refresh_option():
def get_token_refresh_options():
return {"token_refresh_offset": 100}
credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 100


@pytest.mark.asyncio
async def test_bearer_policy_no_refresh_option_method():
policy = AsyncBearerTokenCredentialPolicy(AzureKeyCredential("test"), "scope")
assert policy._token_refresh_offset == 300


@pytest.mark.asyncio
async def test_bearer_policy_refresh_option_method_return_none():
def get_token_refresh_options():
return None

credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 300


@pytest.mark.asyncio
async def test_bearer_policy_refresh_option_method_return_empty():
def get_token_refresh_options():
return dict()

credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 300
31 changes: 31 additions & 0 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,34 @@ def test_azure_key_credential_updates():
api_key = "new"
credential.update(api_key)
assert credential.key == api_key


def test_bearer_policy_refresh_option():
def get_token_refresh_options():
return {"token_refresh_offset": 100}
credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = BearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 100


def test_bearer_policy_no_refresh_option_method():
policy = BearerTokenCredentialPolicy(AzureKeyCredential("test"), "scope")
assert policy._token_refresh_offset == 300


def test_bearer_policy_refresh_option_method_return_none():
def get_token_refresh_options():
return None

credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = BearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 300


def test_bearer_policy_refresh_option_method_return_empty():
def get_token_refresh_options():
return dict()

credential = Mock(get_token_refresh_options=get_token_refresh_options)
policy = BearerTokenCredentialPolicy(credential, "scope")
assert policy._token_refresh_offset == 300