Skip to content

Commit

Permalink
[Communication-Chat]Updated async credentials (Azure#17451)
Browse files Browse the repository at this point in the history
* Updated async credentials
* Fix CI
* Fix tests
* Updated utils
* Pylint fix
  • Loading branch information
annatisch authored Mar 23, 2021
1 parent 3acad11 commit e8efaa5
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions
from .utils import _convert_datetime_to_utc_int


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
Expand All @@ -36,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -80,14 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now_as_int() <\
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now_as_int() < self._token.expires_on

@classmethod
def _get_utc_now_as_int(cls):
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime_as_int = _convert_datetime_to_utc_int(current_utc_datetime)
return current_utc_datetime_as_int
return get_current_utc_as_int() < self._token.expires_on
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,27 @@
# license information.
# --------------------------------------------------------------------------
from asyncio import Condition, Lock
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Any
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions
from .utils import _convert_datetime_to_utc_int


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
:keyword token_refresher: The token refresher to provide capacity to fetch fresh token
:keyword token_refresher: The async token refresher to provide capacity to fetch fresh token
:raises: TypeError
"""

_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
def __init__(self, token: str, **kwargs: Any):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -36,25 +33,24 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""

if not self._token_refresher or not self._token_expiring():
return self._token

should_this_thread_refresh = False

with self._lock:
async with self._lock:

while self._token_expiring():
if self._some_thread_refreshing:
if self._is_currenttoken_valid():
return self._token

self._wait_till_inprogress_thread_finish_refreshing()
await self._wait_till_inprogress_thread_finish_refreshing()
else:
should_this_thread_refresh = True
self._some_thread_refreshing = True
Expand All @@ -63,34 +59,37 @@ def get_token(self):

if should_this_thread_refresh:
try:
newtoken = self._token_refresher() # pylint:disable=not-callable
newtoken = await self._token_refresher() # pylint:disable=not-callable

with self._lock:
async with self._lock:
self._token = newtoken
self._some_thread_refreshing = False
self._lock.notify_all()
except:
with self._lock:
async with self._lock:
self._some_thread_refreshing = False
self._lock.notify_all()

raise

return self._token

def _wait_till_inprogress_thread_finish_refreshing(self):
async def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.release()
self._lock.acquire()
await self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now_as_int() < self._token.expires_on
return get_current_utc_as_int() < self._token.expires_on

async def close(self) -> None:
pass

async def __aenter__(self):
return self

@classmethod
def _get_utc_now_as_int(cls):
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime_as_int = _convert_datetime_to_utc_int(current_utc_datetime)
return current_utc_datetime_as_int
async def __aexit__(self, *args):
await self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from msrest.serialization import TZ_UTC
from azure.core.credentials import AccessToken


def _convert_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())


def parse_connection_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
endpoint = None
Expand Down Expand Up @@ -43,6 +49,13 @@ def get_current_utc_time():
# type: () -> str
return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT"


def get_current_utc_as_int():
# type: () -> int
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
return _convert_datetime_to_utc_int(current_utc_datetime)


def create_access_token(token):
# type: (str) -> azure.core.credentials.AccessToken
"""Creates an instance of azure.core.credentials.AccessToken from a
Expand Down Expand Up @@ -71,6 +84,7 @@ def create_access_token(token):
except ValueError:
raise ValueError(token_parse_err_msg)


def get_authentication_policy(
endpoint, # type: str
credential, # type: TokenCredential or str
Expand Down Expand Up @@ -101,7 +115,3 @@ def get_authentication_policy(

raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy"
"or a token credential from azure.identity".format(type(credential)))

def _convert_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import six
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.exceptions import HttpResponseError
from azure.core.async_paging import AsyncItemPaged

Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(

self._client = AzureCommunicationChatService(
self._endpoint,
authentication_policy=BearerTokenCredentialPolicy(self._credential),
authentication_policy=AsyncBearerTokenCredentialPolicy(self._credential),
sdk_moniker=SDK_MONIKER,
**kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import six
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
from azure.core.async_paging import AsyncItemPaged

from .._shared.user_credential_async import CommunicationTokenCredential
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(

self._client = AzureCommunicationChatService(
endpoint,
authentication_policy=BearerTokenCredentialPolicy(self._credential),
authentication_policy=AsyncBearerTokenCredentialPolicy(self._credential),
sdk_moniker=SDK_MONIKER,
**kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@
import pytest
import time


def _convert_datetime_to_utc_int(input):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
input_datetime_as_int = epoch - time.mktime(input.timetuple())
return input_datetime_as_int


credential = Mock()
credential.get_token = Mock(return_value=AccessToken(
"some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC))
))
async def mock_get_token():
return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC)))

credential = Mock(get_token=mock_get_token)


@pytest.mark.asyncio
async def test_create_chat_thread():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@
import pytest
import time


def _convert_datetime_to_utc_int(input):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
input_datetime_as_int = epoch - time.mktime(input.timetuple())
return input_datetime_as_int

credential = Mock()
credential.get_token = Mock(return_value=AccessToken(
"some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC))
))

async def mock_get_token():
return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=TZ_UTC)))

credential = Mock(get_token=mock_get_token)


@pytest.mark.asyncio
async def test_update_topic():
Expand Down

0 comments on commit e8efaa5

Please sign in to comment.