Skip to content

Commit

Permalink
[Messaging] Stop Unnecessary Refresh of JWT Tokens (#23129)
Browse files Browse the repository at this point in the history
The focus of these changes is to cache JWT tokens used for authorization
of AMQP connections and links to avoid making unnecesary requests due to network
changes while the token is not near expiration.

For features like Service Bus sessions, links can be opened/closed frequently,
leading to a non-trivial amount of requests to acquire tokens that are likely to
remain valid for the duration of their use.

Also included is the addition of a small amount of random jitter when calculating
the time that AMQP authorization should be resent; this is intended to reduce
contention when a new token needs to be acquired.
  • Loading branch information
jsquire authored Aug 16, 2021
1 parent 71e5e6a commit 774caf6
Show file tree
Hide file tree
Showing 9 changed files with 785 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ internal class AmqpConnectionScope : IDisposable
/// <summary>The string formatting mask to apply to the service endpoint to publish events for a given partition.</summary>
private const string PartitionProducerPathSuffixMask = "{0}/Partitions/{1}";

/// <summary>The seed to use for initializing random number generated for a given thread-specific instance.</summary>
private static int s_randomSeed = Environment.TickCount;

/// <summary>The random number generator to use for a specific thread.</summary>
private static readonly ThreadLocal<Random> RandomNumberGenerator = new ThreadLocal<Random>(() => new Random(Interlocked.Increment(ref s_randomSeed)), false);

/// <summary>Indicates whether or not this instance has been disposed.</summary>
private volatile bool _disposed;

Expand All @@ -61,14 +67,22 @@ internal class AmqpConnectionScope : IDisposable
/// than the expected expiration by this amount.
/// </summary>
///
private static TimeSpan AuthorizationRefreshBuffer { get; } = TimeSpan.FromMinutes(10);
private static TimeSpan AuthorizationRefreshBuffer { get; } = TimeSpan.FromMinutes(7);

/// <summary>
/// The amount of seconds to use as the basis for calculating a random jitter amount
/// when refreshing token authorization. This is intended to ensure that multiple
/// resources using the authorization do not all attempt to refresh at the same moment.
/// </summary>
///
private static int AuthorizationBaseJitterSeconds { get; } = 30;

/// <summary>
/// The minimum amount of time for authorization to be refreshed; any calculations that
/// call for refreshing more frequently will be substituted with this value.
/// </summary>
///
private static TimeSpan MinimumAuthorizationRefresh { get; } = TimeSpan.FromMinutes(2);
private static TimeSpan MinimumAuthorizationRefresh { get; } = TimeSpan.FromMinutes(3);

/// <summary>
/// The maximum amount of time to allow before authorization is refreshed; any calculations
Expand All @@ -88,6 +102,14 @@ internal class AmqpConnectionScope : IDisposable
///
private static TimeSpan AuthorizationRefreshTimeout { get; } = TimeSpan.FromMinutes(3);

/// <summary>
/// The amount of buffer to apply when considering an authorization token
/// to be expired. The token's actual expiration will be decreased by this
/// amount, ensuring that it is renewed before it has expired.
/// </summary>
///
private static TimeSpan AuthorizationTokenExpirationBuffer { get; } = AuthorizationRefreshBuffer.Add(TimeSpan.FromMinutes(2));

/// <summary>
/// The recommended timeout to associate with an AMQP session. It is recommended that this
/// interval be used when creating or opening AMQP links and related constructs.
Expand Down Expand Up @@ -237,8 +259,8 @@ public AmqpConnectionScope(Uri serviceEndpoint,
SendBufferSizeInBytes = sendBufferSizeBytes;
ReceiveBufferSizeInBytes = receiveBufferSizeBytes;
CertificateValidationCallback = certificateValidationCallback;
TokenProvider = new CbsTokenProvider(new EventHubTokenCredential(credential), OperationCancellationSource.Token);
Id = identifier ?? $"{ eventHubName }-{ Guid.NewGuid().ToString("D", CultureInfo.InvariantCulture).Substring(0, 8) }";
TokenProvider = new CbsTokenProvider(new EventHubTokenCredential(credential), AuthorizationTokenExpirationBuffer, OperationCancellationSource.Token);

Task<AmqpConnection> connectionFactory(TimeSpan timeout) => CreateAndOpenConnectionAsync(AmqpVersion, ServiceEndpoint, ConnectionEndpoint, Transport, Proxy, SendBufferSizeInBytes, ReceiveBufferSizeInBytes, CertificateValidationCallback, Id, timeout);
ActiveConnection = new FaultTolerantAmqpObject<AmqpConnection>(connectionFactory, CloseConnection);
Expand Down Expand Up @@ -902,7 +924,15 @@ protected virtual TimeSpan CalculateLinkAuthorizationRefreshInterval(DateTime ex
{
currentTimeUtc ??= DateTime.UtcNow;

var refreshDueInterval = (expirationTimeUtc.Subtract(AuthorizationRefreshBuffer)).Subtract(currentTimeUtc.Value);
// Calculate the interval for when refreshing authorization should take place;
// the refresh is based on the time that the credential expires, less a buffer to
// allow for clock skew. A random number of seconds is added as jitter, to prevent
// multiple resources using the same token from all requesting a refresh at the same moment.

var refreshDueInterval = expirationTimeUtc
.Subtract(AuthorizationRefreshBuffer)
.Subtract(currentTimeUtc.Value)
.Subtract(TimeSpan.FromSeconds(RandomNumberGenerator.Value.NextDouble() * AuthorizationBaseJitterSeconds));

return refreshDueInterval switch
{
Expand Down
116 changes: 112 additions & 4 deletions sdk/eventhub/Azure.Messaging.EventHubs/src/Amqp/CbsTokenProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Azure.Messaging.EventHubs.Amqp
///
/// <seealso cref="Microsoft.Azure.Amqp.ICbsTokenProvider" />
///
internal sealed class CbsTokenProvider : ICbsTokenProvider
internal sealed class CbsTokenProvider : ICbsTokenProvider, IDisposable
{
/// <summary>The type to consider a token if it is based on an Event Hubs shared access signature.</summary>
private const string SharedAccessTokenType = "servicebus.windows.net:sastoken";
Expand All @@ -31,27 +31,48 @@ internal sealed class CbsTokenProvider : ICbsTokenProvider
/// <summary>The credential used to generate access tokens.</summary>
private readonly EventHubTokenCredential Credential;

/// <summary>The primitive for synchronizing access when an authorization token is being acquired.</summary>
private readonly SemaphoreSlim TokenAcquireGuard;

/// <summary>The amount of buffer to when evaluating token expiration; the token's expiration date will be adjusted earlier by this amount.</summary>
private readonly TimeSpan TokenExpirationBuffer;

/// <summary>The cancellation token to consider when making requests.</summary>
private readonly CancellationToken CancellationToken;

/// <summary>The JWT-based <see cref="CbsToken" /> that is currently cached for authorization.</summary>
private CbsToken _cachedJwtToken;

/// <summary>
/// Initializes a new instance of the <see cref="CbsTokenProvider"/> class.
/// </summary>
///
/// <param name="credential">The credential to use for access token generation.</param>
/// <param name="tokenExpirationBuffer">The amount of time to buffer expiration</param>
/// <param name="cancellationToken">The cancellation token to consider when making requests.</param>
///
public CbsTokenProvider(EventHubTokenCredential credential,
TimeSpan tokenExpirationBuffer,
CancellationToken cancellationToken)
{
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNegative(tokenExpirationBuffer, nameof(tokenExpirationBuffer));

Credential = credential;
TokenExpirationBuffer = tokenExpirationBuffer;
CancellationToken = cancellationToken;

TokenType = (credential.IsSharedAccessCredential)
? SharedAccessTokenType
: JsonWebTokenType;

// Tokens are only cached for JWT-based credentials; no need
// to instantiate the semaphore if no caching is taking place.

if (!credential.IsSharedAccessCredential)
{
TokenAcquireGuard = new SemaphoreSlim(1, 1);
}
}

/// <summary>
Expand All @@ -62,14 +83,101 @@ public CbsTokenProvider(EventHubTokenCredential credential,
/// <param name="namespaceAddress">The address of the namespace to be authorized.</param>
/// <param name="appliesTo">The resource to which the token should apply.</param>
/// <param name="requiredClaims">The set of claims that are required for authorization.</param>
///
/// <returns>The token to use for authorization.</returns>
///
public async Task<CbsToken> GetTokenAsync(Uri namespaceAddress,
string appliesTo,
string[] requiredClaims)
string[] requiredClaims) =>
Credential switch
{
_ when Credential.IsSharedAccessCredential => await AcquireSharedAccessTokenAsync().ConfigureAwait(false),
_ => await AcquireJwtTokenAsync().ConfigureAwait(false)
};

/// <summary>
/// Performs the task needed to clean up resources used by the <see cref="CbsTokenProvider" />.
/// </summary>
///
public void Dispose() => TokenAcquireGuard?.Dispose();

/// <summary>
/// Acquires a token for a JWT-type credential.
/// </summary>
///
/// <returns>The <see cref="CbsToken"/> to use for authorization.</returns>
///
private async Task<CbsToken> AcquireJwtTokenAsync()
{
AccessToken token = await Credential.GetTokenUsingDefaultScopeAsync(CancellationToken).ConfigureAwait(false);
return new CbsToken(token.Token, TokenType, token.ExpiresOn.UtcDateTime);
var token = Volatile.Read(ref _cachedJwtToken);
var guardAcquired = false;

// If the token is expiring, attempt to acquire the semaphore needed to
// refresh it.

if (IsNearingExpiration(token))
{
try
{
// Attempt to acquire the semaphore synchronously, in case it is not currently
// held.

if (!TokenAcquireGuard.Wait(0, CancellationToken.None))
{
await TokenAcquireGuard.WaitAsync(CancellationToken).ConfigureAwait(false);
}

guardAcquired = true;

// Because the token may have been refreshed while waiting for the semaphore,
// check expiration again to avoid making an necessary request. Because the token
// is only updated when the semaphore is held, there's no need to perform a volatile
// read for this refresh.

token = _cachedJwtToken;

if (IsNearingExpiration(token))
{
var accessToken = await Credential.GetTokenUsingDefaultScopeAsync(CancellationToken).ConfigureAwait(false);
token = new CbsToken(accessToken.Token, TokenType, accessToken.ExpiresOn.UtcDateTime);

_cachedJwtToken = token;
}
}
finally
{
if (guardAcquired)
{
TokenAcquireGuard.Release();
}
}
}

return token;
}

/// <summary>
/// Acquires a token for a shared access credential.
/// </summary>
///
/// <returns>The <see cref="CbsToken"/> to use for authorization.</returns>
///
private async Task<CbsToken> AcquireSharedAccessTokenAsync()
{
var accessToken = await Credential.GetTokenUsingDefaultScopeAsync(CancellationToken).ConfigureAwait(false);
return new CbsToken(accessToken.Token, TokenType, accessToken.ExpiresOn.UtcDateTime);
}

/// <summary>
/// Determines whether the specified token is nearing its expiration and a new
/// token should be acquired.
/// </summary>
///
/// <param name="token">The token to consider.</param>
///
/// <returns><c>true</c> if the specified token is near expiring; otherwise, <c>false</c>.</returns>
///
private bool IsNearingExpiration(CbsToken token) =>
((token == null) || token.ExpiresAtUtc.Subtract(TokenExpirationBuffer) <= DateTimeOffset.UtcNow);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2109,7 +2109,7 @@ public async Task RequestAuthorizationUsingCbsAsyncRespectsTheConnectionClosing(
var idleTimeout = TimeSpan.FromSeconds(30);
var mockCredential = new Mock<TokenCredential>();
var mockEventHubsCredential = new Mock<EventHubTokenCredential>(mockCredential.Object);
var mockTokenProvider = new CbsTokenProvider(mockEventHubsCredential.Object, CancellationToken.None);
var mockTokenProvider = new CbsTokenProvider(mockEventHubsCredential.Object, TimeSpan.Zero, CancellationToken.None);
var mockScope = new MockConnectionScope(endpoint, endpoint, eventHub, mockEventHubsCredential.Object, transport, null, idleTimeout);

// This is brittle, but the AMQP library does not support mocking nor setting this directly.
Expand Down Expand Up @@ -2159,11 +2159,12 @@ public void CalculateLinkAuthorizationRefreshIntervalRespectsTheRefreshBuffer()
var currentTime = new DateTime(2015, 10, 27, 00, 00, 00);
var expireTime = currentTime.AddHours(1);
var buffer = GetAuthorizationRefreshBuffer();
var jitterBuffer = TimeSpan.FromSeconds(GetAuthorizationBaseJitterSeconds()).Add(TimeSpan.FromSeconds(5));
var calculatedRefresh = mockScope.InvokeCalculateLinkAuthorizationRefreshInterval(expireTime, currentTime);
var calculatedExpire = currentTime.Add(calculatedRefresh);

Assert.That(calculatedExpire, Is.LessThan(expireTime), "The refresh should be account for the buffer and be earlier than expiration.");
Assert.That(calculatedExpire, Is.EqualTo(expireTime.Subtract(buffer)).Within(TimeSpan.FromSeconds(2)), "The authorization buffer should have been used for buffering.");
Assert.That(calculatedExpire, Is.EqualTo(expireTime.Subtract(buffer)).Within(jitterBuffer), "The authorization buffer should have been used for buffering.");
}

/// <summary>
Expand All @@ -2180,10 +2181,12 @@ public void CalculateLinkAuthorizationRefreshIntervalRespectsTheMinimumDuration(
var mockScope = new MockConnectionScope(endPoint, endPoint, "test", credential.Object, EventHubsTransportType.AmqpTcp, null, idleTimeout);
var currentTime = new DateTime(2015, 10, 27, 00, 00, 00);
var minimumRefresh = GetMinimumAuthorizationRefresh();
var jitterBuffer = TimeSpan.FromSeconds(GetAuthorizationBaseJitterSeconds()).Add(TimeSpan.FromSeconds(5));
var expireTime = currentTime.Add(minimumRefresh.Subtract(TimeSpan.FromMilliseconds(500)));
var calculatedRefresh = mockScope.InvokeCalculateLinkAuthorizationRefreshInterval(expireTime, currentTime);

Assert.That(calculatedRefresh, Is.EqualTo(minimumRefresh), "The minimum refresh duration should have been used.");
Assert.That(calculatedRefresh, Is.GreaterThanOrEqualTo(minimumRefresh), "The minimum refresh duration should be violated.");
Assert.That(calculatedRefresh, Is.EqualTo(minimumRefresh).Within(jitterBuffer), "The minimum refresh duration should have been used.");
}

/// <summary>
Expand All @@ -2200,11 +2203,13 @@ public void CalculateLinkAuthorizationRefreshIntervalRespectsTheMaximumDuration(
var mockScope = new MockConnectionScope(endPoint, endPoint, "test", credential.Object, EventHubsTransportType.AmqpTcp, null, idleTimeout);
var currentTime = new DateTime(2015, 10, 27, 00, 00, 00);
var refreshBuffer = GetAuthorizationRefreshBuffer();
var jitterBuffer = TimeSpan.FromSeconds(GetAuthorizationBaseJitterSeconds()).Add(TimeSpan.FromSeconds(5));
var maximumRefresh = GetMaximumAuthorizationRefresh();
var expireTime = currentTime.Add(maximumRefresh.Add(refreshBuffer).Add(TimeSpan.FromMilliseconds(500)));
var calculatedRefresh = mockScope.InvokeCalculateLinkAuthorizationRefreshInterval(expireTime, currentTime);

Assert.That(calculatedRefresh, Is.EqualTo(maximumRefresh), "The maximum refresh duration should have been used.");
Assert.That(calculatedRefresh, Is.LessThanOrEqualTo(maximumRefresh), "The maximum refresh duration should not be exceeded.");
Assert.That(calculatedRefresh, Is.EqualTo(maximumRefresh).Within(jitterBuffer), "The maximum refresh duration should have been used.");
}

/// <summary>
Expand Down Expand Up @@ -2315,6 +2320,17 @@ private static TimeSpan GetMaximumAuthorizationRefresh() =>
.GetProperty("MaximumAuthorizationRefresh", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.GetProperty)
.GetValue(null);

/// <summary>
/// Gets the base time used to calculate random jitter for refreshing authorization,
/// using the private accessor.
/// </summary>
///
private static int GetAuthorizationBaseJitterSeconds() =>
(int)
typeof(AmqpConnectionScope)
.GetProperty("AuthorizationBaseJitterSeconds", BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.GetProperty)
.GetValue(null);

/// <summary>
/// Creates a set of dummy settings for testing purposes.
/// </summary>
Expand Down
Loading

0 comments on commit 774caf6

Please sign in to comment.