Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve token refresh logic #277

Merged
merged 5 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions samples/asynctests/test_azure_event_hubs_send_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import os
import logging
import asyncio
import time
from datetime import timedelta
import pytest
import sys
import types
import collections

import uamqp
from uamqp import types as uamqp_types, utils, authentication
from uamqp import types as uamqp_types, utils, authentication, constants

_AccessToken = collections.namedtuple("AccessToken", "token expires_on")


def get_logger(level):
Expand Down Expand Up @@ -237,6 +241,60 @@ async def test_event_hubs_send_large_message_after_socket_lost(live_eventhub_con
finally:
await send_client.close_async()


@pytest.mark.asyncio
async def test_event_hubs_send_override_token_refresh_window(live_eventhub_config):
uri = "sb://{}/{}".format(live_eventhub_config['hostname'], live_eventhub_config['event_hub'])
target = "amqps://{}/{}/Partitions/0".format(live_eventhub_config['hostname'], live_eventhub_config['event_hub'])
token = None

async def get_token():
nonlocal token
return _AccessToken(token, expiry)

jwt_auth = authentication.JWTTokenAsync(
uri,
uri,
get_token,
override_token_refresh_window=300 # set refresh window to be 5 mins
)

send_client = uamqp.SendClientAsync(target, auth=jwt_auth, debug=False)

# use token of which the valid remaining time < refresh window
expiry = int(time.time()) + (60 * 4 + 30) # 4.5 minutes
token = utils.create_sas_token(
live_eventhub_config['key_name'].encode(),
live_eventhub_config['access_key'].encode(),
uri.encode(),
expiry=timedelta(minutes=4, seconds=30)
)

for _ in range(3):
message = uamqp.message.Message(body='Hello World')
await send_client.send_message_async(message)

auth_status = constants.CBSAuthStatus(jwt_auth._cbs_auth.get_status())
assert auth_status == constants.CBSAuthStatus.RefreshRequired

# update token, the valid remaining time > refresh window
expiry = int(time.time()) + (60 * 5 + 30) # 5.5 minutes
token = utils.create_sas_token(
live_eventhub_config['key_name'].encode(),
live_eventhub_config['access_key'].encode(),
uri.encode(),
expiry=timedelta(minutes=5, seconds=30)
)

for _ in range(3):
message = uamqp.message.Message(body='Hello World')
await send_client.send_message_async(message)

auth_status = constants.CBSAuthStatus(jwt_auth._cbs_auth.get_status())
assert auth_status == constants.CBSAuthStatus.Ok
await send_client.close_async()


if __name__ == '__main__':
config = {}
config['hostname'] = os.environ['EVENT_HUB_HOSTNAME']
Expand All @@ -247,4 +305,4 @@ async def test_event_hubs_send_large_message_after_socket_lost(live_eventhub_con
config['partition'] = "0"

loop = asyncio.get_event_loop()
loop.run_until_complete(test_event_hubs_send_large_message_after_socket_lost(config))
loop.run_until_complete(test_event_hubs_send_override_token_refresh_window(config))
58 changes: 56 additions & 2 deletions samples/test_azure_event_hubs_send.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from datetime import timedelta
import types
import pytest
import collections

import uamqp
from uamqp import types as uamqp_types, utils, authentication
from uamqp import types as uamqp_types, utils, authentication, constants

_AccessToken = collections.namedtuple("AccessToken", "token expires_on")

def get_logger(level):
uamqp_logger = logging.getLogger("uamqp")
Expand Down Expand Up @@ -264,6 +266,58 @@ def test_event_hubs_send_large_message_after_socket_lost(live_eventhub_config):
finally:
send_client.close()


def test_event_hubs_send_override_token_refresh_window(live_eventhub_config):
uri = "sb://{}/{}".format(live_eventhub_config['hostname'], live_eventhub_config['event_hub'])
target = "amqps://{}/{}/Partitions/0".format(live_eventhub_config['hostname'], live_eventhub_config['event_hub'])
token = [None]

def get_token():
return _AccessToken(token[0], expiry)

jwt_auth = authentication.JWTTokenAuth(
uri,
uri,
get_token,
override_token_refresh_window=300 # set refresh window to be 5 mins
)

send_client = uamqp.SendClient(target, auth=jwt_auth, debug=False)

# use token of which the valid remaining time < refresh window
expiry = int(time.time()) + (60 * 4 + 30) # 4.5 minutes
token[0] = utils.create_sas_token(
live_eventhub_config['key_name'].encode(),
live_eventhub_config['access_key'].encode(),
uri.encode(),
expiry=timedelta(minutes=4, seconds=30)
)

for _ in range(3):
message = uamqp.message.Message(body='Hello World')
send_client.send_message(message)

auth_status = constants.CBSAuthStatus(jwt_auth._cbs_auth.get_status())
assert auth_status == constants.CBSAuthStatus.RefreshRequired

# update token, the valid remaining time > refresh window
expiry = int(time.time()) + (60 * 5 + 30) # 5.5 minutes
token[0] = utils.create_sas_token(
live_eventhub_config['key_name'].encode(),
live_eventhub_config['access_key'].encode(),
uri.encode(),
expiry=timedelta(minutes=5, seconds=30)
)

for _ in range(3):
message = uamqp.message.Message(body='Hello World')
send_client.send_message(message)

auth_status = constants.CBSAuthStatus(jwt_auth._cbs_auth.get_status())
assert auth_status == constants.CBSAuthStatus.Ok
send_client.close()


if __name__ == '__main__':
config = {}
config['hostname'] = os.environ['EVENT_HUB_HOSTNAME']
Expand All @@ -273,4 +327,4 @@ def test_event_hubs_send_large_message_after_socket_lost(live_eventhub_config):
config['consumer_group'] = "$Default"
config['partition'] = "0"

test_event_hubs_send_large_message_after_socket_lost(config)
test_event_hubs_send_override_token_refresh_window(config)
11 changes: 7 additions & 4 deletions src/cbs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cdef class CBSTokenAuth(object):
cdef const char* connection_id
cdef cSession _session

def __cinit__(self, const char* audience, const char* token_type, const char* token, stdint.uint64_t expires_at, cSession session, stdint.uint64_t timeout, const char* connection_id):
def __cinit__(self, const char* audience, const char* token_type, const char* token, stdint.uint64_t expires_at, cSession session, stdint.uint64_t timeout, const char* connection_id, stdint.uint64_t override_token_refresh_window):
self.state = AUTH_STATUS_IDLE
self.audience = audience
self.token_type = token_type
Expand All @@ -61,9 +61,12 @@ cdef class CBSTokenAuth(object):
self.auth_timeout = timeout
self.connection_id = connection_id
self._token_put_time = 0
current_time = int(time.time())
remaining_time = expires_at - current_time
self._refresh_window = int(float(remaining_time) * 0.1)
if override_token_refresh_window > 0:
self._refresh_window = override_token_refresh_window
else:
current_time = int(time.time())
remaining_time = expires_at - current_time
self._refresh_window = int(float(remaining_time) * 0.1)
self._cbs_handle = c_cbs.cbs_create(<c_session.SESSION_HANDLE>session._c_value)
self._session = session
if <void*>self._cbs_handle == NULL:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------

from uamqp.message import MessageProperties, MessageHeader, Message, constants, errors, c_uamqp
import pickle
import pytest
Expand Down
19 changes: 17 additions & 2 deletions uamqp/authentication/cbs_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def create_authenticator(self, connection, debug=False, **kwargs):
int(self.expires_at),
self._session._session, # pylint: disable=protected-access
self.timeout,
self._connection.container_id)
self._connection.container_id,
self._override_token_refresh_window
)
self._cbs_auth.set_trace(debug)
except ValueError:
self._session.destroy()
Expand Down Expand Up @@ -137,7 +139,14 @@ def handle_token(self):
_logger.info("Token on connection %r will expire soon - attempting to refresh.",
self._connection.container_id)
self.update_token()
self._cbs_auth.refresh(self.token, int(self.expires_at))
if self.token != self._prev_token:
self._cbs_auth.refresh(self.token, int(self.expires_at))
else:
_logger.info(
"The newly acquired token on connection %r is the same as the previous one,"
" will keep attempting to refresh",
self._connection.container_id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this log statement going to get very noisy? Or is it very rare that token == previous_token after calling update_token?

Copy link
Contributor Author

@yunhaoling yunhaoling Oct 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it would be very noisy in the case token == previous_token (which happens now when there's a gap between the client refresh window (uAMQP 6 mins to refresh) and service refresh window (AAD service 5 mins refresh))

So once we have fixed the issue in uamqp/eh/sb, changing default refresh window to 5 mins, it should be rare -- at least not worse than what we currently have.


I think what uamqp is doing technically is not wrong -- it's doing its best to refresh token, however it's doing a bit excessively, retry without backoff. So ultimately we might want to introduce kinda backoff to token refresh to let it not throttle.

with that being said, I think we could keep this "potential" noisy logging for now until user complains.

)
elif auth_status == constants.CBSAuthStatus.Idle:
self._cbs_auth.authenticate()
in_progress = True
Expand Down Expand Up @@ -231,6 +240,8 @@ def __init__(self, audience, uri, token,
**kwargs): # pylint: disable=no-member
self._retry_policy = retry_policy
self._encoding = encoding
self._override_token_refresh_window = kwargs.pop("override_token_refresh_window", 0)
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved
self._prev_token = None
self.uri = uri
parsed = compat.urlparse(uri) # pylint: disable=no-member

Expand Down Expand Up @@ -259,6 +270,7 @@ def update_token(self):
encoded_uri = compat.quote_plus(self.uri).encode(self._encoding) # pylint: disable=no-member
encoded_key = compat.quote_plus(self.username).encode(self._encoding) # pylint: disable=no-member
self.expires_at = time.time() + self.expires_in.seconds
self._prev_token = self.token
self.token = utils.create_sas_token(
encoded_key,
self.password.encode(self._encoding),
Expand Down Expand Up @@ -399,6 +411,8 @@ def __init__(self, audience, uri,
**kwargs): # pylint: disable=no-member
self._retry_policy = retry_policy
self._encoding = encoding
self._override_token_refresh_window = kwargs.pop("override_token_refresh_window", 0)
self._prev_token = None
self.uri = uri
parsed = compat.urlparse(uri) # pylint: disable=no-member

Expand All @@ -425,4 +439,5 @@ def create_authenticator(self, connection, debug=False, **kwargs):
def update_token(self):
access_token = self.get_token()
self.expires_at = access_token.expires_on
self._prev_token = self.token
self.token = self._encode(access_token.token)
16 changes: 14 additions & 2 deletions uamqp/authentication/cbs_auth_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ async def create_authenticator_async(self, connection, debug=False, loop=None, *
int(self.expires_at),
self._session._session, # pylint: disable=protected-access
self.timeout,
self._connection.container_id)
self._connection.container_id,
self._override_token_refresh_window
)
self._cbs_auth.set_trace(debug)
except ValueError:
await self._session.destroy_async()
Expand Down Expand Up @@ -132,7 +134,14 @@ async def handle_token_async(self):
_logger.info("Token on connection %r will expire soon - attempting to refresh.",
self._connection.container_id)
await self.update_token()
self._cbs_auth.refresh(self.token, int(self.expires_at))
if self.token != self._prev_token:
self._cbs_auth.refresh(self.token, int(self.expires_at))
else:
_logger.info(
"The newly acquired token on connection %r is the same as the previous one,"
" will keep attempting to refresh",
self._connection.container_id
)
elif auth_status == constants.CBSAuthStatus.Idle:
self._cbs_auth.authenticate()
in_progress = True
Expand Down Expand Up @@ -260,6 +269,8 @@ def __init__(self, audience, uri,
**kwargs): # pylint: disable=no-member
self._retry_policy = retry_policy
self._encoding = encoding
self._override_token_refresh_window = kwargs.pop("override_token_refresh_window", 0)
self._prev_token = None
self.uri = uri
parsed = compat.urlparse(uri) # pylint: disable=no-member

Expand All @@ -285,4 +296,5 @@ async def create_authenticator_async(self, connection, debug=False, loop=None, *
async def update_token(self):
access_token = await self.get_token()
self.expires_at = access_token.expires_on
self._prev_token = self.token
self.token = self._encode(access_token.token)