diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py index 9343975f5925..9b5f17dcc95d 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py @@ -4,16 +4,15 @@ # license information. # -------------------------------------------------------------------------- from threading import Lock, Condition -from datetime import datetime, timedelta +from datetime import timedelta from typing import ( # pylint: disable=unused-import cast, Tuple, ) -from msrest.serialization import TZ_UTC - +from .utils import get_current_utc_as_int from .user_token_refresh_options import CommunicationTokenRefreshOptions -from .utils import _convert_datetime_to_utc_int + class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. @@ -36,8 +35,8 @@ def __init__(self, self._lock = Condition(Lock()) self._some_thread_refreshing = False - def get_token(self): - # type () -> ~azure.core.credentials.AccessToken + def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument + # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ @@ -80,14 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self): self._lock.acquire() def _token_expiring(self): - return self._token.expires_on - self._get_utc_now_as_int() <\ + return self._token.expires_on - get_current_utc_as_int() <\ timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() def _is_currenttoken_valid(self): - return self._get_utc_now_as_int() < self._token.expires_on - - @classmethod - def _get_utc_now_as_int(cls): - current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC) - current_utc_datetime_as_int = _convert_datetime_to_utc_int(current_utc_datetime) - return current_utc_datetime_as_int + return get_current_utc_as_int() < self._token.expires_on diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py index 24252a783961..52a99e7a4b6a 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py @@ -4,30 +4,27 @@ # license information. # -------------------------------------------------------------------------- from asyncio import Condition, Lock -from datetime import datetime, timedelta +from datetime import timedelta from typing import ( # pylint: disable=unused-import cast, Tuple, + Any ) -from msrest.serialization import TZ_UTC - +from .utils import get_current_utc_as_int from .user_token_refresh_options import CommunicationTokenRefreshOptions -from .utils import _convert_datetime_to_utc_int + class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The token refresher to provide capacity to fetch fresh token + :keyword token_refresher: The async token refresher to provide capacity to fetch fresh token :raises: TypeError """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 - def __init__(self, - token, # type: str - **kwargs - ): + def __init__(self, token: str, **kwargs: Any): token_refresher = kwargs.pop('token_refresher', None) communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, token_refresher=token_refresher) @@ -36,25 +33,24 @@ def __init__(self, self._lock = Condition(Lock()) self._some_thread_refreshing = False - def get_token(self): - # type () -> ~azure.core.credentials.AccessToken + async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument + # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ - if not self._token_refresher or not self._token_expiring(): return self._token should_this_thread_refresh = False - with self._lock: + async with self._lock: while self._token_expiring(): if self._some_thread_refreshing: if self._is_currenttoken_valid(): return self._token - self._wait_till_inprogress_thread_finish_refreshing() + await self._wait_till_inprogress_thread_finish_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True @@ -63,14 +59,14 @@ def get_token(self): if should_this_thread_refresh: try: - newtoken = self._token_refresher() # pylint:disable=not-callable + newtoken = await self._token_refresher() # pylint:disable=not-callable - with self._lock: + async with self._lock: self._token = newtoken self._some_thread_refreshing = False self._lock.notify_all() except: - with self._lock: + async with self._lock: self._some_thread_refreshing = False self._lock.notify_all() @@ -78,19 +74,22 @@ def get_token(self): return self._token - def _wait_till_inprogress_thread_finish_refreshing(self): + async def _wait_till_inprogress_thread_finish_refreshing(self): self._lock.release() - self._lock.acquire() + await self._lock.acquire() def _token_expiring(self): - return self._token.expires_on - self._get_utc_now_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() + return self._token.expires_on - get_current_utc_as_int() <\ + timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() def _is_currenttoken_valid(self): - return self._get_utc_now_as_int() < self._token.expires_on + return get_current_utc_as_int() < self._token.expires_on + + async def close(self) -> None: + pass + + async def __aenter__(self): + return self - @classmethod - def _get_utc_now_as_int(cls): - current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC) - current_utc_datetime_as_int = _convert_datetime_to_utc_int(current_utc_datetime) - return current_utc_datetime_as_int + async def __aexit__(self, *args): + await self.close() diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py index a4395e12643c..f7f4046b0b74 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py @@ -15,6 +15,12 @@ from msrest.serialization import TZ_UTC from azure.core.credentials import AccessToken + +def _convert_datetime_to_utc_int(expires_on): + epoch = time.mktime(datetime(1970, 1, 1).timetuple()) + return epoch-time.mktime(expires_on.timetuple()) + + def parse_connection_str(conn_str): # type: (str) -> Tuple[str, str, str, str] endpoint = None @@ -43,6 +49,13 @@ def get_current_utc_time(): # type: () -> str return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" + +def get_current_utc_as_int(): + # type: () -> int + current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC) + return _convert_datetime_to_utc_int(current_utc_datetime) + + def create_access_token(token): # type: (str) -> azure.core.credentials.AccessToken """Creates an instance of azure.core.credentials.AccessToken from a @@ -71,6 +84,7 @@ def create_access_token(token): except ValueError: raise ValueError(token_parse_err_msg) + def get_authentication_policy( endpoint, # type: str credential, # type: TokenCredential or str @@ -101,7 +115,3 @@ def get_authentication_policy( raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy" "or a token credential from azure.identity".format(type(credential))) - -def _convert_datetime_to_utc_int(expires_on): - epoch = time.mktime(datetime(1970, 1, 1).timetuple()) - return epoch-time.mktime(expires_on.timetuple()) diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py index 44c4bf0b91c4..7b9fc6d4a2a2 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_client_async.py @@ -16,7 +16,7 @@ import six from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async -from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.exceptions import HttpResponseError from azure.core.async_paging import AsyncItemPaged @@ -86,7 +86,7 @@ def __init__( self._client = AzureCommunicationChatService( self._endpoint, - authentication_policy=BearerTokenCredentialPolicy(self._credential), + authentication_policy=AsyncBearerTokenCredentialPolicy(self._credential), sdk_moniker=SDK_MONIKER, **kwargs) diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py index 758f996d7e71..f2388f972185 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/aio/_chat_thread_client_async.py @@ -15,7 +15,7 @@ import six from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async -from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.async_paging import AsyncItemPaged from .._shared.user_credential_async import CommunicationTokenCredential @@ -103,7 +103,7 @@ def __init__( self._client = AzureCommunicationChatService( endpoint, - authentication_policy=BearerTokenCredentialPolicy(self._credential), + authentication_policy=AsyncBearerTokenCredentialPolicy(self._credential), sdk_moniker=SDK_MONIKER, **kwargs) diff --git a/sdk/communication/azure-communication-chat/tests/test_chat_client_async.py b/sdk/communication/azure-communication-chat/tests/test_chat_client_async.py index 596c7f5c2145..568aed5faa56 100644 --- a/sdk/communication/azure-communication-chat/tests/test_chat_client_async.py +++ b/sdk/communication/azure-communication-chat/tests/test_chat_client_async.py @@ -25,16 +25,18 @@ import pytest import time + def _convert_datetime_to_utc_int(input): epoch = time.mktime(datetime(1970, 1, 1).timetuple()) input_datetime_as_int = epoch - time.mktime(input.timetuple()) return input_datetime_as_int -credential = Mock() -credential.get_token = Mock(return_value=AccessToken( - "some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC)) -)) +async def mock_get_token(): + return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC))) + +credential = Mock(get_token=mock_get_token) + @pytest.mark.asyncio async def test_create_chat_thread(): diff --git a/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_async.py b/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_async.py index 2fa9e9488d99..3b1c6e8cb4f2 100644 --- a/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_async.py +++ b/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_async.py @@ -25,15 +25,18 @@ import pytest import time + def _convert_datetime_to_utc_int(input): epoch = time.mktime(datetime(1970, 1, 1).timetuple()) input_datetime_as_int = epoch - time.mktime(input.timetuple()) return input_datetime_as_int -credential = Mock() -credential.get_token = Mock(return_value=AccessToken( - "some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC)) -)) + +async def mock_get_token(): + return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC))) + +credential = Mock(get_token=mock_get_token) + @pytest.mark.asyncio async def test_update_topic():