Skip to content

Commit

Permalink
Move caching stuff off the IAccessTokenProvider API into the factory (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
j3parker authored Nov 29, 2024
1 parent 234de0f commit 5adb1c8
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Net.Http;
using System.Security.Cryptography;
using System.Threading;
using D2L.Security.OAuth2.Caching;
using D2L.Security.OAuth2.Keys;
using D2L.Security.OAuth2.Keys.Default;
using D2L.Security.OAuth2.Keys.Development;
Expand Down Expand Up @@ -38,9 +39,9 @@ public static IAccessTokenProvider Create( HttpClient httpClient, String tokenPr
Uri authEndpoint = new Uri( tokenProvisioningEndpoint );
ITokenSigner tokenSigner = new TokenSigner( privateKeyProvider );
IAuthServiceClient authServiceClient = new AuthServiceClient( httpClient, authEndpoint );
INonCachingAccessTokenProvider noCacheTokenProvider = new AccessTokenProvider( tokenSigner, authServiceClient );
IAccessTokenProvider noCacheTokenProvider = new AccessTokenProvider( tokenSigner, authServiceClient );

return new CachedAccessTokenProvider( noCacheTokenProvider, authEndpoint, Timeout.InfiniteTimeSpan );
return new CachedAccessTokenProvider( new NullCache(), noCacheTokenProvider, authEndpoint, Timeout.InfiniteTimeSpan );
}

}
Expand Down
33 changes: 25 additions & 8 deletions src/D2L.Security.OAuth2/Provisioning/AccessTokenProviderFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Net.Http;
using D2L.Security.OAuth2.Caching;
using D2L.Security.OAuth2.Keys;
using D2L.Security.OAuth2.Provisioning.Default;

Expand All @@ -18,18 +19,34 @@ public static IAccessTokenProvider Create(
ITokenSigner tokenSigner,
HttpClient httpClient,
Uri authEndpoint,
TimeSpan tokenRefreshGracePeriod
TimeSpan tokenRefreshGracePeriod,
IAccessTokenProvider inner = null,
ICache cache = null
) {
if( inner != null && cache == null ) {
throw new InvalidOperationException( "If you provide an inner you need to also provide its cache" );
}

IAuthServiceClient authServiceClient = new AuthServiceClient(
httpClient,
authEndpoint
);
if( inner == null ) {
IAuthServiceClient authServiceClient = new AuthServiceClient(
httpClient,
authEndpoint
);

inner = new AccessTokenProvider( tokenSigner, authServiceClient );
}

INonCachingAccessTokenProvider accessTokenProvider =
new AccessTokenProvider( tokenSigner, authServiceClient );

return new CachedAccessTokenProvider( accessTokenProvider, authEndpoint, tokenRefreshGracePeriod );
if( cache != null ) {
return inner;
}

return new CachedAccessTokenProvider(
cache,
inner,
authEndpoint,
tokenRefreshGracePeriod
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace D2L.Security.OAuth2.Provisioning.Default {

internal sealed partial class AccessTokenProvider : INonCachingAccessTokenProvider {
internal sealed partial class AccessTokenProvider : IAccessTokenProvider {

private readonly IAuthServiceClient m_client;
private readonly ITokenSigner m_tokenSigner;
Expand All @@ -23,7 +23,7 @@ IAuthServiceClient authServiceClient
}

[GenerateSync]
async Task<IAccessToken> INonCachingAccessTokenProvider.ProvisionAccessTokenAsync(
async Task<IAccessToken> IAccessTokenProvider.ProvisionAccessTokenAsync(
IEnumerable<Claim> claimSet,
IEnumerable<Scope> scopes
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@

namespace D2L.Security.OAuth2.Provisioning.Default {
internal sealed partial class CachedAccessTokenProvider : IAccessTokenProvider {
private readonly INonCachingAccessTokenProvider m_accessTokenProvider;
private readonly ICache m_cache;
private readonly IAccessTokenProvider m_inner;
private readonly Uri m_authEndpoint;
private readonly TimeSpan m_tokenRefreshGracePeriod;
private readonly JwtSecurityTokenHandler m_tokenHandler;

public CachedAccessTokenProvider(
INonCachingAccessTokenProvider accessTokenProvider,
ICache cache,
IAccessTokenProvider inner,
Uri authEndpoint,
TimeSpan tokenRefreshGracePeriod
) {
m_accessTokenProvider = accessTokenProvider;
m_cache = cache;
m_inner = inner;
m_authEndpoint = authEndpoint;
m_tokenRefreshGracePeriod = tokenRefreshGracePeriod;

Expand All @@ -35,19 +38,15 @@ TimeSpan tokenRefreshGracePeriod
[GenerateSync]
async Task<IAccessToken> IAccessTokenProvider.ProvisionAccessTokenAsync(
IEnumerable<Claim> claims,
IEnumerable<Scope> scopes,
ICache cache
IEnumerable<Scope> scopes
) {
if( cache == null ) {
cache = new NullCache();
}

claims = claims.ToList();
scopes = scopes.ToList();

string cacheKey = TokenCacheKeyBuilder.BuildKey( m_authEndpoint, claims, scopes );

CacheResponse cacheResponse = await cache.GetAsync( cacheKey ).ConfigureAwait( false );
CacheResponse cacheResponse = await m_cache.GetAsync( cacheKey )
.ConfigureAwait( false );

if( cacheResponse.Success ) {
SecurityToken securityToken = m_tokenHandler.ReadToken( cacheResponse.Value );
Expand All @@ -56,12 +55,19 @@ ICache cache
}
}

IAccessToken token =
await m_accessTokenProvider.ProvisionAccessTokenAsync( claims, scopes ).ConfigureAwait( false );
IAccessToken token = await m_inner.ProvisionAccessTokenAsync(
claims,
scopes
).ConfigureAwait( false );

DateTime validTo = m_tokenHandler.ReadToken( token.Token ).ValidTo;

await cache.SetAsync( cacheKey, token.Token, validTo - DateTime.UtcNow ).ConfigureAwait( false );
await m_cache.SetAsync(
cacheKey,
token.Token,
expiry: validTo - DateTime.UtcNow
).ConfigureAwait( false );

return token;
}
}
Expand Down

This file was deleted.

5 changes: 1 addition & 4 deletions src/D2L.Security.OAuth2/Provisioning/IAccessTokenProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Security.Claims;
using System.Threading.Tasks;
using D2L.CodeStyle.Annotations;
using D2L.Security.OAuth2.Caching;
using D2L.Security.OAuth2.Scopes;

namespace D2L.Security.OAuth2.Provisioning {
Expand All @@ -16,15 +15,13 @@ public partial interface IAccessTokenProvider {
/// </summary>
/// <param name="claims">The set of claims to be included in the token.</param>
/// <param name="scopes">The set of scopes to be included in the token.</param>
/// <param name="cache">The provided <see cref="ICache"/> does not need to
/// check for token expiration or grace period because the
/// <see cref="IAccessTokenProvider"/> will handle it internally.</param>
/// <returns>An access token containing an expiry and the provided claims and scopes.</returns>
[GenerateSync]
Task<IAccessToken> ProvisionAccessTokenAsync(
IEnumerable<Claim> claims,
IEnumerable<Scope> scopes,
ICache cache = null
IEnumerable<Scope> scopes
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ internal sealed class CachedAccessTokenProviderTests {
private readonly Claim[] m_claims = { new Claim( "abc", "123" ), new Claim( "xyz", "789" ) };
private readonly Scope[] m_scopes = { new Scope( "a", "b", "c" ), new Scope( "7", "8", "9" ) };

private Mock<INonCachingAccessTokenProvider> m_accessTokenProviderMock;
private Mock<IAccessTokenProvider> m_accessTokenProviderMock;
private Mock<ICache> m_userTokenCacheMock;
private Mock<ICache> m_serviceTokenCacheMock;

[SetUp]
public void Setup() {
m_accessTokenProviderMock = new Mock<INonCachingAccessTokenProvider>( MockBehavior.Strict );
m_accessTokenProviderMock = new Mock<IAccessTokenProvider>( MockBehavior.Strict );
m_serviceTokenCacheMock = new Mock<ICache>( MockBehavior.Strict );
m_userTokenCacheMock = new Mock<ICache>( MockBehavior.Strict );
}
Expand All @@ -51,10 +51,10 @@ public async Task ProvisionAccessTokenAsync_NotCached_CallsThroughToAccessTokenP
m_serviceTokenCacheMock.Setup( x => x.SetAsync( It.IsAny<string>(), It.IsAny<string>(), It.IsAny<TimeSpan>() ) )
.Returns( Task.CompletedTask );

IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider();
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( m_serviceTokenCacheMock.Object );

IAccessToken token =
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes, m_serviceTokenCacheMock.Object ).ConfigureAwait( false );
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes ).ConfigureAwait( false );
Assert.AreEqual( accessToken.Token, token.Token );
}

Expand All @@ -64,10 +64,10 @@ public async Task ProvisionAccessTokenAsync_AlreadyCached_UsesCachedValueAndDoes
m_serviceTokenCacheMock.Setup( x => x.GetAsync( It.IsAny<string>() ) )
.Returns( Task.FromResult( new CacheResponse( true, BuildTestToken() ) ) );

IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider();
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( m_serviceTokenCacheMock.Object );

IAccessToken token =
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes, m_serviceTokenCacheMock.Object ).ConfigureAwait( false );
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes ).ConfigureAwait( false );
Assert.NotNull( token );
}

Expand All @@ -85,10 +85,10 @@ public async Task ProvisionAccessTokenAsync_TokenIsAlreadyCachedButIsWithinGrace
.Returns( Task.FromResult( accessToken ) );

const int gracePeriodThatIsBiggerThanTimeToExpiry = TOKEN_EXPIRY_IN_SECONDS + 60;
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( gracePeriodThatIsBiggerThanTimeToExpiry );
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( m_serviceTokenCacheMock.Object, gracePeriodThatIsBiggerThanTimeToExpiry );

IAccessToken token =
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes, m_serviceTokenCacheMock.Object ).ConfigureAwait( false );
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes ).ConfigureAwait( false );
Assert.NotNull( token );
}

Expand All @@ -98,11 +98,11 @@ public async Task ProvisionAccessTokenAsync_UserClaimProvided_UserCacheUsed() {
m_userTokenCacheMock.Setup( x => x.GetAsync( It.IsAny<string>() ) )
.Returns( Task.FromResult( new CacheResponse( true, BuildTestToken( specifyUserClaim: true ) ) ) );

IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider();
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( m_userTokenCacheMock.Object );

Claim userClaim = new Claim( Constants.Claims.USER_ID, "user" );
IAccessToken token =
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( new[] { userClaim }, m_scopes, m_userTokenCacheMock.Object ).ConfigureAwait( false );
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( new[] { userClaim }, m_scopes ).ConfigureAwait( false );
Assert.NotNull( token );
}

Expand All @@ -112,10 +112,10 @@ public async Task ProvisionAccessTokenAsync_ServiceClaimProvided_ServiceCacheUse
m_serviceTokenCacheMock.Setup( x => x.GetAsync( It.IsAny<string>() ) )
.Returns( Task.FromResult( new CacheResponse( true, BuildTestToken( specifyUserClaim: false ) ) ) );

IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider();
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( m_serviceTokenCacheMock.Object );

IAccessToken token =
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes, m_serviceTokenCacheMock.Object ).ConfigureAwait( false );
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( m_claims, m_scopes ).ConfigureAwait( false );
Assert.NotNull( token );
}

Expand All @@ -131,14 +131,15 @@ public async Task ProvisionAccessTokenAsync_CallPassThroughOverload_CallsOtherOv
new Claim( Constants.Claims.ISSUER, "TheIssuer" )
};

IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider();
IAccessTokenProvider cachedAccessTokenProvider = GetCachedAccessTokenProvider( m_serviceTokenCacheMock.Object );
IAccessToken token =
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( claimSet, Enumerable.Empty<Scope>(), m_serviceTokenCacheMock.Object ).ConfigureAwait( false );
await cachedAccessTokenProvider.ProvisionAccessTokenAsync( claimSet, Enumerable.Empty<Scope>() ).ConfigureAwait( false );
Assert.NotNull( token );
}

private IAccessTokenProvider GetCachedAccessTokenProvider( int tokenRefreshGracePeriod = 120 ) {
private IAccessTokenProvider GetCachedAccessTokenProvider( ICache cache, int tokenRefreshGracePeriod = 120 ) {
return new CachedAccessTokenProvider(
cache,
m_accessTokenProviderMock.Object,
new Uri( "https://example.com" ),
TimeSpan.FromSeconds( tokenRefreshGracePeriod )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ private static class TestData {

private IPublicKeyDataProvider m_publicKeyDataProvider;
private ITokenSigner m_tokenSigner;
private INonCachingAccessTokenProvider m_accessTokenProvider;
private IAccessTokenProvider m_accessTokenProvider;
private JwtSecurityToken m_actualAssertion;

[SetUp]
Expand Down

0 comments on commit 5adb1c8

Please sign in to comment.